From 34a8aadb35e4cba0677c58086e442ef103ba9813 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 16 May 2017 17:47:00 +0200 Subject: [PATCH] Implement merge for weighted average --- src/weighted_average.rs | 64 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/weighted_average.rs b/src/weighted_average.rs index ff50dc7..5c6fe08 100644 --- a/src/weighted_average.rs +++ b/src/weighted_average.rs @@ -118,6 +118,36 @@ impl WeightedAverage { }; (variance / self.weight_sum).sqrt() } + + /// Merge the weighted average of another sequence into this one. + /// + /// ``` + /// use average::WeightedAverage; + /// + /// let weighted_sequence: &[(f64, f64)] = &[ + /// (1., 0.1), (2., 0.2), (3., 0.3), (4., 0.4), (5., 0.5), + /// (6., 0.6), (7., 0.7), (8., 0.8), (9., 0.)]; + /// let (left, right) = weighted_sequence.split_at(3); + /// let avg_total: WeightedAverage = weighted_sequence.iter().map(|&x| x).collect(); + /// let mut avg_left: WeightedAverage = left.iter().map(|&x| x).collect(); + /// let avg_right: WeightedAverage = right.iter().map(|&x| x).collect(); + /// avg_left.merge(&avg_right); + /// assert!((avg_total.mean() - avg_left.mean()).abs() < 1e-15); + /// assert!((avg_total.sample_variance() - avg_left.sample_variance()).abs() < 1e-15); + /// ``` + pub fn merge(&mut self, other: &WeightedAverage) { + // This is similar to the algorithm proposed by Chan et al. in 1979. + // + // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. + let delta = other.avg - self.avg; + let total_weight_sum = self.weight_sum + other.weight_sum; + self.avg = (self.weight_sum * self.avg + other.weight_sum * other.avg) + / (self.weight_sum + other.weight_sum); + self.v += other.v + delta*delta * self.weight_sum * other.weight_sum + / total_weight_sum; + self.weight_sum = total_weight_sum; + self.weight_sum_sq += other.weight_sum_sq; + } } impl core::default::Default for WeightedAverage { @@ -194,4 +224,38 @@ mod tests { .map(|(x, w)| (*x, *w)).collect(); assert_eq!(a.error(), 0.5); } + + #[test] + fn merge_unweighted() { + let sequence: &[f64] = &[1., 2., 3., 4., 5., 6., 7., 8., 9.]; + for mid in 0..sequence.len() { + let (left, right) = sequence.split_at(mid); + let avg_total: WeightedAverage = sequence.iter().map(|x| (*x, 1.)).collect(); + let mut avg_left: WeightedAverage = left.iter().map(|x| (*x, 1.)).collect(); + let avg_right: WeightedAverage = right.iter().map(|x| (*x, 1.)).collect(); + avg_left.merge(&avg_right); + assert_eq!(avg_total.weight_sum, avg_left.weight_sum); + assert_eq!(avg_total.weight_sum_sq, avg_left.weight_sum_sq); + assert_eq!(avg_total.avg, avg_left.avg); + assert_eq!(avg_total.v, avg_left.v); + } + } + + #[test] + fn merge_weighted() { + let sequence: &[(f64, f64)] = &[ + (1., 0.1), (2., 0.2), (3., 0.3), (4., 0.4), (5., 0.5), + (6., 0.6), (7., 0.7), (8., 0.8), (9., 0.)]; + for mid in 0..sequence.len() { + let (left, right) = sequence.split_at(mid); + let avg_total: WeightedAverage = sequence.iter().map(|&(x, w)| (x, w)).collect(); + let mut avg_left: WeightedAverage = left.iter().map(|&(x, w)| (x, w)).collect(); + let avg_right: WeightedAverage = right.iter().map(|&(x, w)| (x, w)).collect(); + avg_left.merge(&avg_right); + assert_almost_eq!(avg_total.weight_sum, avg_left.weight_sum, 1e-15); + assert_eq!(avg_total.weight_sum_sq, avg_left.weight_sum_sq); + assert_almost_eq!(avg_total.avg, avg_left.avg, 1e-15); + assert_almost_eq!(avg_total.v, avg_left.v, 1e-14); + } + } }