Fixed notation issues in the “Faster multiset updates” section. Thank you Joonas.
Let’s say you have a multiset (bag) of “reals” (floats or rationals), where each value is a sampled observation. It’s easy to augment any implementation of the multiset ADT to also return the sample mean of the values in the multiset in constant time: track the sum of values in the multiset, as they are individually added and removed. This requires one accumulator and a counter for the number of observations in the multiset (i.e., constant space), and adds a constant-time overhead to each update.
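A minimal sketch of such an augmented bag might look like the following; the MeanBag name and its methods are made up for illustration.

```python
import collections


class MeanBag:  # hypothetical name, for illustration only
    """Multiset of floats that also tracks the sample mean in O(1)."""

    def __init__(self):
        self.counts = collections.Counter()  # value -> multiplicity
        self.total = 0.0                     # running sum of all values
        self.n = 0                           # number of observations

    def add(self, v):
        self.counts[v] += 1
        self.total += v
        self.n += 1

    def remove(self, v):
        assert self.counts[v] > 0
        self.counts[v] -= 1
        self.total -= v
        self.n -= 1

    def mean(self):
        return self.total / self.n if self.n else 0.0
```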
It’s not as simple when you also need the sample variance of the multiset \(X\), i.e.,
\[\frac{1}{n - 1} \sum\sb{x \in X} (x - \hat{x})\sp{2},\]
where \(n = |X|\) is the sample size and \(\hat{x}\) is the sample mean \(\sum\sb{x\in X} x/n,\) ideally with constant query time and constant-time update overhead.
One could try to apply the textbook equality
\[s\sp{2} = \frac{1}{n(n-1)}\left[n\sum\sb{x\in X} x\sp{2} - \left(\sum\sb{x\in X} x\right)\sp{2}\right].\]
However, as Knuth notes in TAoCP volume 2,
this expression loses a lot of precision to round-off in floating point:
in extreme cases, the difference might be negative
(and we know the variance is never negative).
More commonly, we’ll lose precision
when the sampled values are clustered around a large mean.
For example, the sample standard deviation of 1e8 and 1e8 - 1
is the same as that of 0 and 1 (\(\sqrt{1/2} \approx 0.71\)).
However, the expression above would evaluate the former to 0.0,
even in double precision:
while 1e8 is comfortably within range for double floats,
its square 1e16 is outside the range where all integers are represented exactly.
Knuth refers to a better behaved recurrence by Welford, where
a running sample mean is subtracted from each new observation
before squaring.
John Cook has a C++
implementation
of the recurrence that adds observations to a sample variance in constant time.
In Python, this streaming algorithm looks like this.
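Here’s a minimal sketch of that update; the StreamingVariance and observe names follow the identifiers used later in this post, and get_mean / get_variance are assumed accessors.

```python
class StreamingVariance:
    """Welford's streaming mean / variance accumulator."""

    def __init__(self):
        self.n = 0          # number of observations
        self.mean = 0.0     # running sample mean
        self.var_sum = 0.0  # sum of squared differences from the mean

    def observe(self, v):
        self.n += 1
        old_mean = self.mean
        self.mean += (v - old_mean) / self.n
        self.var_sum += (v - old_mean) * (v - self.mean)

    def get_mean(self):
        return self.mean

    def get_variance(self):
        # Sample variance, with Bessel's correction.
        return self.var_sum / (self.n - 1) if self.n > 1 else 0.0
```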
That’s all we need for insert-only multisets, but it does not handle removals; if only we had removals, we could always implement updates (replacement) as a removal followed by an insertion.
Luckily, StreamingVariance.observe looks invertible.
It shouldn’t be hard to recover the previous sample mean, given v,
and, given the current and previous sample means,
we can re-evaluate (v - old_mean) * (v - self.mean)
and subtract it from self.var_sum.
Let \(\hat{x}\sp{\prime}\) be the sample mean after observe(v).
We can derive the previous sample mean \(\hat{x}\) from \(v\):
\[(n - 1)\hat{x} = n\hat{x}\sp{\prime} - v \Leftrightarrow \hat{x} = \hat{x}\sp{\prime} + \frac{\hat{x}\sp{\prime} - v}{n-1}.\]
This invertibility means that we can undo calls to observe
in LIFO order. We can’t handle arbitrary multiset updates, only a
stack of observations. That’s still better than nothing.
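A sketch of that observation stack follows; I assume pop receives the value that was pushed last (the test driver below keeps track of it), rather than the stack storing its own copy.

```python
class VarianceStack:
    """Streaming mean / variance whose observations can be popped in LIFO order."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.var_sum = 0.0

    def push(self, v):
        self.n += 1
        old_mean = self.mean
        self.mean += (v - old_mean) / self.n
        self.var_sum += (v - old_mean) * (v - self.mean)

    def pop(self, v):
        """Undo the most recent push(v)."""
        assert self.n > 0
        if self.n == 1:
            self.n, self.mean, self.var_sum = 0, 0.0, 0.0
            return
        # Recover the mean before v was pushed:
        #   (n - 1) old_mean = n mean - v  =>  old_mean = mean + (mean - v) / (n - 1)
        old_mean = self.mean + (self.mean - v) / (self.n - 1)
        self.var_sum -= (v - old_mean) * (v - self.mean)
        self.mean = old_mean
        self.n -= 1

    def get_mean(self):
        return self.mean

    def get_variance(self):
        return self.var_sum / (self.n - 1) if self.n > 1 else 0.0
```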
Before going any further, let’s test this.
Testing the VarianceStack
The best way to test the VarianceStack is to execute a series of
push and pop calls, and compare the results of get_mean and
get_variance with batch reference implementations.
I could hardcode calls in unit tests. However, that quickly hits diminishing returns in terms of marginal coverage vs. developer time. Instead, I’ll be lazy, completely skip unit tests, and rely on Hypothesis, its high-level “stateful” testing API in particular.
We’ll keep track of the values pushed and popped off the observation stack in the driver: we must make sure they’re matched in LIFO order, and we need the stack’s contents to compute the reference mean and variance. We’ll also want to compare the results with reference implementations, modulo some numerical noise. Let’s try to be aggressive and bound the number of float values between the reference and the actual results.
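A condensed sketch of such a driver follows; the strategy and the ULP-based comparison helper are guesses at reasonable defaults, and this version reproduces the reference-implementation slips the tests flush out below.

```python
import math
import sys

from hypothesis import strategies as st
from hypothesis.stateful import RuleBasedStateMachine, precondition, rule


def reference_mean(values):
    return sum(values) / len(values) if values else 0.0


def reference_variance(values):
    # Batch reference: recompute everything from scratch on each call.
    if len(values) < 2:
        return 0.0
    mean = reference_mean(values)
    return sum((x - mean) ** 2 for x in values) / len(values)


def assert_almost_equal(x, y, max_ulp=10):
    # Aggressive comparison: x and y must be within a few doubles of each other.
    delta = max_ulp * math.ulp(max(abs(x), abs(y), sys.float_info.min))
    assert abs(x - y) <= delta, (x, y)


class VarianceStackDriver(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self.values = []  # shadow stack of pushed observations

    @rule(v=st.floats(allow_nan=False, allow_infinity=False))
    def push(self, v):
        self.values.append(v)
        self.check_agreement()

    @precondition(lambda self: self.values)
    @rule()
    def pop(self):
        self.values.pop()
        self.check_agreement()

    def check_agreement(self):
        # The VarianceStack isn't hooked in yet: only exercise the
        # reference implementations and the comparison helper.
        assert_almost_equal(reference_mean(self.values), reference_mean(self.values))
        assert_almost_equal(reference_variance(self.values), reference_variance(self.values))


TestVarianceStack = VarianceStackDriver.TestCase
```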
This initial driver does not even use the VarianceStack
yet.
All it does is push values to the reference stack,
pop values when the stack has something to pop,
and check that the reference implementations match themselves after each call:
I want to first shake out any bug in the test harness itself.
Not surprisingly, Hypothesis does find an issue in the reference implementation:
Falsifying example:
state = VarianceStackDriver()
state.push(v=0.0)
state.push(v=2.6815615859885194e+154)
state.teardown()
We get a numerical OverflowError in reference_variance:
2.68...e154 / 2 is slightly greater than
sqrt(sys.float_info.max) = 1.3407807929942596e+154,
so taking the square of that value errors out instead of returning infinity.
Let’s start by clamping the range of the generated floats.
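One way to do that is a shared strategy for the push rule; the exact bound is arbitrary, as long as squared deviations stay finite.

```python
# Anything comfortably below sqrt(sys.float_info.max) can be squared
# without erroring out; 1e150 is a conservative, arbitrary choice.
SANE_FLOATS = st.floats(
    min_value=-1e150,
    max_value=1e150,
    allow_nan=False,
    allow_infinity=False,
)
```

The push rule then draws v=SANE_FLOATS instead of unbounded floats.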
Now that the test harness doesn’t find fault in itself,
let’s hook in the VarianceStack, and see what happens
when only push calls are generated (i.e., first test
only the standard streaming variance algorithm).
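A sketch of the updated driver, with the pop rule temporarily removed:

```python
class VarianceStackDriver(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        self.values = []               # shadow stack of pushed observations
        self.stack = VarianceStack()   # implementation under test

    @rule(v=SANE_FLOATS)
    def push(self, v):
        self.values.append(v)
        self.stack.push(v)
        self.check_agreement()

    # pop() is disabled for now: first exercise the plain streaming algorithm.

    def check_agreement(self):
        assert_almost_equal(self.stack.get_mean(), reference_mean(self.values))
        assert_almost_equal(self.stack.get_variance(), reference_variance(self.values))


TestVarianceStack = VarianceStackDriver.TestCase
```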
This already fails horribly.
Falsifying example:
state = VarianceStackDriver()
state.push(v=1.0)
state.push(v=1.488565707357403e+138)
state.teardown()
F
The reference finds a variance of 5.54e275,
which is very much not the streaming computation’s 1.108e276.
We can manually check that the reference is wrong:
it’s missing the n - 1 correction term in the denominator.
We should use this updated reference.
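Something like this, with Bessel’s correction:

```python
def reference_variance(values):
    n = len(values)
    if n < 2:
        return 0.0
    mean = reference_mean(values)
    # Sample variance: divide by n - 1 (Bessel's correction), not n.
    return sum((x - mean) ** 2 for x in values) / (n - 1)
```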
Let’s now re-enable calls to pop().
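A sketch of the restored rule, back in VarianceStackDriver and assuming VarianceStack.pop takes the popped value:

```python
    @precondition(lambda self: self.values)
    @rule()
    def pop(self):
        # Pop the same value from the shadow stack and the stack under test.
        self.stack.pop(self.values.pop())
        self.check_agreement()
```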
And now things fail in new and excitingly numerical ways.
Falsifying example:
state = VarianceStackDriver()
state.push(v=0.0)
state.push(v=0.00014142319560050964)
state.push(v=14188.9609375)
state.pop()
state.teardown()
F
This counter-example fails with the online variance returning 0.0
instead of 1e-8.
That’s not unexpected:
removing (the square of) a large value from a running sum
spells catastrophic cancellation.
It’s also not that bad for my use case,
where I don’t expect to observe very large values.
Another problem for our test harness is that floats are very dense
around 0.0, and I’m ok with small (around 1e-8) absolute error
because the input and output will be single floats.
Let’s relax assert_almost_equal, and restrict generated observations
to fall in \([-2\sp{12}, 2\sp{12}].\)
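For instance (the exact tolerances are judgment calls):

```python
# Observations are restricted to a modest range...
OBSERVATIONS = st.floats(min_value=-(2.0 ** 12), max_value=2.0 ** 12)


def assert_almost_equal(x, y, max_ulp=10, abs_tol=1e-8):
    # ...and the comparison now accepts a small absolute error in addition
    # to a few ULPs of relative error: inputs and outputs are single floats.
    delta = max(abs_tol, max_ulp * math.ulp(max(abs(x), abs(y))))
    assert abs(x - y) <= delta, (x, y)
```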
With all these tweaks to make sure we generate easy (i.e., interesting) test cases, Hypothesis fails to find a failure after its default time budget.
I’m willing to call that a victory.
From stack to full multiset
We have tested code to undo updates in Welford’s classic streaming variance algorithm.
Unfortunately, inverting pushes away only works for LIFO edits,
and we’re looking for arbitrary inserts and removals (and updates) to a multiset
of observations.
However, both the mean \(\hat{x} = \sum\sb{x\in X} x/n\) and the centered second moment \(\sum\sb{x\in X}(x - \hat{x})\sp{2}\) are order-independent: they’re just sums over all observations. Disregarding round-off, we’ll find the same mean and second moment regardless of the order in which the observations were pushed in. Thus, whenever we wish to remove an observation from the multiset, we can assume it was the last one added to the estimates, and pop it off.
We think we know how to implement running mean and variance for a multiset of observations. How do we test that with Hypothesis?
The hardest part about testing dictionary (map)-like interfaces is making sure to generate valid identifiers when removing values. As it turns out, Hypothesis has built-in support for this important use case, with its Bundles. We’ll use that to test a dictionary from observation name to observation value, augmented to keep track of the current mean and variance of all values.
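A condensed sketch of that state machine, reusing the strategy and helpers from the stack tests; I assume a VarianceBag with push, pop, and update methods (update can initially be pop followed by push) and the same accessors as the VarianceStack.

```python
from hypothesis import strategies as st
from hypothesis.stateful import (
    Bundle,
    RuleBasedStateMachine,
    consumes,
    multiple,
    rule,
)


class VarianceBagDriver(RuleBasedStateMachine):
    keys = Bundle("keys")

    def __init__(self):
        super().__init__()
        self.entries = {}          # reference dict: name -> value
        self.bag = VarianceBag()   # implementation under test

    @rule(target=keys, key=st.text(), v=OBSERVATIONS)
    def add_entry(self, key, v):
        if key in self.entries:
            self.update_entry(key, v)
            return multiple()  # the key is already in the bundle
        self.entries[key] = v
        self.bag.push(v)
        self.check_agreement()
        return key

    @rule(key=consumes(keys))
    def del_entry(self, key):
        self.bag.pop(self.entries.pop(key))
        self.check_agreement()

    @rule(key=keys, v=OBSERVATIONS)
    def update_entry(self, key, v):
        old = self.entries[key]
        self.entries[key] = v
        self.bag.update(old, v)
        self.check_agreement()

    def check_agreement(self):
        values = list(self.entries.values())
        if values:
            assert_almost_equal(self.bag.get_mean(), reference_mean(values))
            assert_almost_equal(self.bag.get_variance(), reference_variance(values))


TestVarianceBag = VarianceBagDriver.TestCase
```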
Each call to add_entry will either go to update_entry if
the key already exists, or add an observation to the dictionary
and streaming estimator. If we have a new key, it is added
to the keys Bundle; calls to del_entry and update_entry
draw keys from this Bundle. When we remove an entry, it’s
also consumed from the keys Bundle.
Hypothesis finds no fault with our new implementation of dictionary-with-variance,
but update seems like it could be much faster and more numerically stable,
and I intend to mostly use this data structure for calls to update.
Faster multiset updates
The key operation for my use-case is to update one observation
by replacing its old
value with a new
one.
We can maintain the estimator by popping old
away and pushing new
in,
but this business of updating the number of observations n and
rescaling everything seems like a lot of numerical trouble.
We should be able to do better.
We’re replacing the multiset of sampled observations \(X\) with \(X\sp{\prime} = X \setminus \{\textrm{old}\} \cup \{\textrm{new}\}.\) It’s easy to maintain the mean after this update: \(\hat{x}\sp{\prime} = \hat{x} + (\textrm{new} - \textrm{old})/n.\)
The update to self.var_sum
, the sum of squared differences from the mean, is trickier.
We start with \(v = \sum\sb{x\in X} (x - \hat{x})\sp{2},\)
and we wish to find \(v\sp{\prime} = \sum\sb{x\sp{\prime}\in X\sp{\prime}} (x\sp{\prime} - \hat{x}\sp{\prime})\sp{2}.\)
Let \(\delta = \textrm{new} - \textrm{old}\) and \(\delta\sb{\hat{x}} = \delta/n.\) We have \[\sum\sb{x\in X} (x - \hat{x}\sp{\prime})\sp{2} = \sum\sb{x\in X} [(x - \hat{x}) - \delta\sb{\hat{x}}]\sp{2},\] and \[[(x - \hat{x}) - \delta\sb{\hat{x}}]\sp{2} = (x - \hat{x})\sp{2} - 2\delta\sb{\hat{x}} (x - \hat{x}) + \delta\sb{\hat{x}}\sp{2}.\]
We can reassociate the sum, and find
\[\sum\sb{x\in X} (x - \hat{x}\sp{\prime})\sp{2} = \sum\sb{x\in X} (x - \hat{x})\sp{2} - 2\delta\sb{\hat{x}} \left(\sum\sb{x \in X} x - n\hat{x}\right) + n \delta\sb{\hat{x}}\sp{2}.\]
Once we notice that \(\hat{x} = \sum\sb{x\in X} x/n,\) it’s clear that the middle term sums to zero, and we find the very reasonable
\[v\sb{\hat{x}\sp{\prime}} = \sum\sb{x\in X} (x - \hat{x})\sp{2} + n \delta\sb{\hat{x}}\sp{2} = v + \delta \delta\sb{\hat{x}}.\]
This new accumulator \(v\sb{\hat{x}\sp{\prime}}\) corresponds to the sum of the
squared differences between the old observations \(X\) and the new mean \(\hat{x}\sp{\prime}\).
We still have to update one observation from old
to new
.
The remaining adjustment to \(v\) (self.var_sum
) corresponds to
going from \((\textrm{old} - \hat{x}\sp{\prime})\sp{2}\)
to \((\textrm{new} - \hat{x}\sp{\prime})\sp{2},\)
where \(\textrm{new} = \textrm{old} + \delta.\)
After a bit of algebra, we get \[(\textrm{new} - \hat{x}\sp{\prime})\sp{2} = [(\textrm{old} - \hat{x}\sp{\prime}) + \delta]\sp{2} = (\textrm{old} - \hat{x}\sp{\prime})\sp{2} + \delta (\textrm{old} - \hat{x}\sp{\prime} + \textrm{new} - \hat{x}\sp{\prime}).\]
The adjusted \(v\sb{\hat{x}\sp{\prime}}\) already includes
\((\textrm{old} - \hat{x}\sp{\prime})\sp{2}\)
in its sum, so we only have to add the last term
to obtain the final updated self.var_sum
\[v\sp{\prime} = v\sb{\hat{x}\sp{\prime}} + \delta (\textrm{old} - \hat{x}\sp{\prime} + \textrm{new} - \hat{x}\sp{\prime}) = v + \delta (\textrm{old} - \hat{x} + \textrm{new} - \hat{x}\sp{\prime}).\]
That’s our final implementation for VarianceBag.update
,
for which Hypothesis also fails to find failures.
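A sketch matching the formulas above; I assume VarianceBag extends the VarianceStack, reusing its n, mean, and var_sum fields.

```python
class VarianceBag(VarianceStack):
    """VarianceStack plus constant-time replacement of one observation."""

    def update(self, old, new):
        assert self.n > 0
        delta = new - old
        delta_mean = delta / self.n
        new_mean = self.mean + delta_mean
        # v' = v + delta * [(old - mean) + (new - new_mean)]
        self.var_sum += delta * ((old - self.mean) + (new - new_mean))
        self.mean = new_mean
```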
How much do you trust testing?
We have automated property-based tests and some human-checked proofs. Ship it?
I was initially going to ask a CAS
to check my reformulations,
but the implicit \(\forall\) looked messy.
Instead, I decided to check the induction hypothesis implicit in
VarianceBag.update
, and enumerate all cases up to a certain number
of values with Z3 in IPython.
In [1]: from z3 import *
In [2]: x, y, z, new_x = Reals("x y z new_x")
In [3]: mean = (x + y + z) / 3
In [4]: var_sum = sum((v - mean) * (v - mean) for v in (x, y, z))
In [5]: delta = new_x - x
In [6]: new_mean = mean + delta / 3
In [7]: delta_mean = delta / 3
In [8]: adjustment = delta * (2 * (x - mean) + (delta - delta_mean))
In [9]: new_var_sum = var_sum + adjustment
# We have our expressions. Let's check equivalence for mean, then var_sum
In [10]: s = Solver()
In [11]: s.push()
In [12]: s.add(new_mean != (new_x + y + z) / 3)
In [13]: s.check()
Out[13]: unsat # No counter example of size 3 for the updated mean
In [14]: s.pop()
In [15]: s.push()
In [16]: s.add(new_mean == (new_x + y + z) / 3) # We know the mean matches
In [17]: s.add(new_var_sum != sum((v - new_mean) * (v - new_mean) for v in (new_x, y, z)))
In [18]: s.check()
Out[18]: unsat # No counter example of size 3 for the updated variance
Given this script, it’s a small matter of programming to generalise
from 3 values (x
, y
, and z
) to any fixed number of values, and
generate all small cases up to, e.g., 10 values.
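A sketch of that generalisation, reusing the updated_expressions and test_num_var names mentioned below (the internals of the original functions may differ):

```python
from z3 import Real, Solver, unsat


def updated_expressions(n):
    """Symbolic mean / var_sum for n observations, before and after
    replacing the first observation x_0 with new_x."""
    xs = [Real(f"x_{i}") for i in range(n)]
    new_x = Real("new_x")
    mean = sum(xs) / n
    var_sum = sum((v - mean) * (v - mean) for v in xs)
    delta = new_x - xs[0]
    delta_mean = delta / n
    new_mean = mean + delta_mean
    new_var_sum = var_sum + delta * (2 * (xs[0] - mean) + (delta - delta_mean))
    return xs, new_x, new_mean, new_var_sum


def test_num_var(n):
    """Check there is no counter-example of size n for the update formulas."""
    xs, new_x, new_mean, new_var_sum = updated_expressions(n)
    new_values = [new_x] + xs[1:]
    expected_mean = sum(new_values) / n
    expected_var_sum = sum((v - expected_mean) * (v - expected_mean) for v in new_values)

    s = Solver()
    s.push()
    s.add(new_mean != expected_mean)
    result = s.check()
    assert result == unsat, f"mean update refuted for n={n}: {result}"
    s.pop()

    s.add(new_mean == expected_mean)
    s.add(new_var_sum != expected_var_sum)
    result = s.check()
    assert result == unsat, f"var_sum update refuted for n={n}: {result}"


for n in range(2, 11):  # all small sizes, up to 10 values
    test_num_var(n)
    print(f"n = {n}: ok")
```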
I find the most important thing when it comes to using automated proofs is to insert errors and confirm we can find the bugs we’re looking for.
I did that by manually mutating the expressions for new_mean
and new_var_sum in updated_expressions.
This let me find a simple bug in the initial implementation of
test_num_var: I used if not result instead of result != unsat,
and both sat and unsat are truthy.
The code initially failed to flag a failure when z3
found a counter-example for our correctness condition!
And now I’m satisfied
I have code to augment an arbitrary multiset or dictionary with a running estimate of the mean and variance; that code is based on a classic recurrence, with some new math checked by hand, with automated tests, and with some exhaustive checking of small inputs (to which I claim most bugs can be reduced).
I’m now pretty sure the code works, but there’s another more obviously correct way to solve that update problem. This 2008 report by Philippe Pébay1 presents formulas to compute the mean, variance, and arbitrary moments in one pass, and shows how to combine accumulators, a useful operation in parallel computing.
We could use these formulas to augment an arbitrary \(k\)-ary tree and re-combine the merged accumulator as we go back up the (search) tree from the modified leaf to the root. The update would be much more stable (we only add and merge observations), and incur logarithmic time overhead (with linear space overhead). However, given the same time budget, and a logarithmic space overhead, we could also implement the constant-time update with arbitrary precision software floats, and probably guarantee even better precision.
The constant-time update I described in this post demanded more effort to convince myself of its correctness, but I think it’s always a better option than an augmented tree for serial code, especially if initial values are available to populate the accumulators with batch-computed mean and variance. I’m pretty sure the code works, and it’s up in this gist. I’ll be re-implementing it in C++ because that’s the language used by the project that led me to this problem; feel free to steal that gist.
- There’s also a 2016 journal article by Pébay and others with numerical experiments, but I failed to implement their simpler-looking scalar update…