Implement merge for weighted average

This commit is contained in:
Vinzent Steinberg 2017-05-16 17:47:00 +02:00
parent e54d4aba8b
commit 34a8aadb35

View File

@ -118,6 +118,36 @@ impl WeightedAverage {
}; };
(variance / self.weight_sum).sqrt() (variance / self.weight_sum).sqrt()
} }
/// Merge the weighted average of another sequence into this one.
///
/// After the call, `self` holds the combined weight sums, mean and
/// variance accumulator of both sequences; `other` is left untouched.
///
/// A side whose total weight is zero cannot influence the combined mean.
/// Such a side is therefore skipped instead of being fed into the general
/// formula, where it would produce `0/0` or multiply a possibly
/// non-finite mean by zero and turn the result into NaN.
///
/// ```
/// use average::WeightedAverage;
///
/// let weighted_sequence: &[(f64, f64)] = &[
///     (1., 0.1), (2., 0.2), (3., 0.3), (4., 0.4), (5., 0.5),
///     (6., 0.6), (7., 0.7), (8., 0.8), (9., 0.)];
/// let (left, right) = weighted_sequence.split_at(3);
/// let avg_total: WeightedAverage = weighted_sequence.iter().cloned().collect();
/// let mut avg_left: WeightedAverage = left.iter().cloned().collect();
/// let avg_right: WeightedAverage = right.iter().cloned().collect();
/// avg_left.merge(&avg_right);
/// assert!((avg_total.mean() - avg_left.mean()).abs() < 1e-15);
/// assert!((avg_total.sample_variance() - avg_left.sample_variance()).abs() < 1e-15);
/// ```
pub fn merge(&mut self, other: &WeightedAverage) {
    // This is similar to the algorithm proposed by Chan et al. in 1979.
    //
    // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.

    // The sum of squared weights is purely additive and does not depend on
    // either mean, so it can be accumulated before any early return.
    self.weight_sum_sq += other.weight_sum_sq;

    if other.weight_sum == 0. {
        // `other` carries no weight: mean and variance accumulator stay as
        // they are. This also covers merging two empty averages, which
        // would otherwise divide zero by zero.
        return;
    }
    if self.weight_sum == 0. {
        // `self` carries no weight: the merged statistics are exactly those
        // of `other`.
        self.avg = other.avg;
        self.v = other.v;
        self.weight_sum = other.weight_sum;
        return;
    }

    let delta = other.avg - self.avg;
    let total_weight_sum = self.weight_sum + other.weight_sum;
    // Weight-proportional combination of the two means.
    self.avg = (self.weight_sum * self.avg + other.weight_sum * other.avg)
        / total_weight_sum;
    // Both accumulators plus a correction term for the distance between
    // the two means.
    self.v += other.v + delta * delta * self.weight_sum * other.weight_sum
        / total_weight_sum;
    self.weight_sum = total_weight_sum;
}
} }
impl core::default::Default for WeightedAverage { impl core::default::Default for WeightedAverage {
@ -194,4 +224,38 @@ mod tests {
.map(|(x, w)| (*x, *w)).collect(); .map(|(x, w)| (*x, *w)).collect();
assert_eq!(a.error(), 0.5); assert_eq!(a.error(), 0.5);
} }
/// Merging two halves of a unit-weight sequence must reproduce exactly the
/// accumulator built from the whole sequence, for every split position
/// (including an empty left half).
#[test]
fn merge_unweighted() {
    let samples: &[f64] = &[1., 2., 3., 4., 5., 6., 7., 8., 9.];
    // The reference accumulator does not depend on the split point.
    let avg_total: WeightedAverage = samples.iter().map(|&x| (x, 1.)).collect();
    for split in 0..samples.len() {
        let (head, tail) = samples.split_at(split);
        let mut merged: WeightedAverage = head.iter().map(|&x| (x, 1.)).collect();
        let rest: WeightedAverage = tail.iter().map(|&x| (x, 1.)).collect();
        merged.merge(&rest);
        assert_eq!(avg_total.weight_sum, merged.weight_sum);
        assert_eq!(avg_total.weight_sum_sq, merged.weight_sum_sq);
        assert_eq!(avg_total.avg, merged.avg);
        assert_eq!(avg_total.v, merged.v);
    }
}
/// Merging two halves of a weighted sequence must agree with the
/// accumulator built from the whole sequence, up to floating-point
/// rounding, for every split position (including an empty left half).
#[test]
fn merge_weighted() {
    let samples: &[(f64, f64)] = &[
        (1., 0.1), (2., 0.2), (3., 0.3), (4., 0.4), (5., 0.5),
        (6., 0.6), (7., 0.7), (8., 0.8), (9., 0.)];
    // The reference accumulator does not depend on the split point.
    let avg_total: WeightedAverage = samples.iter().cloned().collect();
    for split in 0..samples.len() {
        let (head, tail) = samples.split_at(split);
        let mut merged: WeightedAverage = head.iter().cloned().collect();
        let rest: WeightedAverage = tail.iter().cloned().collect();
        merged.merge(&rest);
        assert_almost_eq!(avg_total.weight_sum, merged.weight_sum, 1e-15);
        assert_eq!(avg_total.weight_sum_sq, merged.weight_sum_sq);
        assert_almost_eq!(avg_total.avg, merged.avg, 1e-15);
        assert_almost_eq!(avg_total.v, merged.v, 1e-14);
    }
}
} }