Implement merge for weighted average
This commit is contained in:
parent
e54d4aba8b
commit
34a8aadb35
@ -118,6 +118,36 @@ impl WeightedAverage {
|
|||||||
};
|
};
|
||||||
(variance / self.weight_sum).sqrt()
|
(variance / self.weight_sum).sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Merge the weighted average of another sequence into this one.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use average::WeightedAverage;
|
||||||
|
///
|
||||||
|
/// let weighted_sequence: &[(f64, f64)] = &[
|
||||||
|
/// (1., 0.1), (2., 0.2), (3., 0.3), (4., 0.4), (5., 0.5),
|
||||||
|
/// (6., 0.6), (7., 0.7), (8., 0.8), (9., 0.)];
|
||||||
|
/// let (left, right) = weighted_sequence.split_at(3);
|
||||||
|
/// let avg_total: WeightedAverage = weighted_sequence.iter().map(|&x| x).collect();
|
||||||
|
/// let mut avg_left: WeightedAverage = left.iter().map(|&x| x).collect();
|
||||||
|
/// let avg_right: WeightedAverage = right.iter().map(|&x| x).collect();
|
||||||
|
/// avg_left.merge(&avg_right);
|
||||||
|
/// assert!((avg_total.mean() - avg_left.mean()).abs() < 1e-15);
|
||||||
|
/// assert!((avg_total.sample_variance() - avg_left.sample_variance()).abs() < 1e-15);
|
||||||
|
/// ```
|
||||||
|
pub fn merge(&mut self, other: &WeightedAverage) {
|
||||||
|
// This is similar to the algorithm proposed by Chan et al. in 1979.
|
||||||
|
//
|
||||||
|
// See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
|
||||||
|
let delta = other.avg - self.avg;
|
||||||
|
let total_weight_sum = self.weight_sum + other.weight_sum;
|
||||||
|
self.avg = (self.weight_sum * self.avg + other.weight_sum * other.avg)
|
||||||
|
/ (self.weight_sum + other.weight_sum);
|
||||||
|
self.v += other.v + delta*delta * self.weight_sum * other.weight_sum
|
||||||
|
/ total_weight_sum;
|
||||||
|
self.weight_sum = total_weight_sum;
|
||||||
|
self.weight_sum_sq += other.weight_sum_sq;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl core::default::Default for WeightedAverage {
|
impl core::default::Default for WeightedAverage {
|
||||||
@ -194,4 +224,38 @@ mod tests {
|
|||||||
.map(|(x, w)| (*x, *w)).collect();
|
.map(|(x, w)| (*x, *w)).collect();
|
||||||
assert_eq!(a.error(), 0.5);
|
assert_eq!(a.error(), 0.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn merge_unweighted() {
|
||||||
|
let sequence: &[f64] = &[1., 2., 3., 4., 5., 6., 7., 8., 9.];
|
||||||
|
for mid in 0..sequence.len() {
|
||||||
|
let (left, right) = sequence.split_at(mid);
|
||||||
|
let avg_total: WeightedAverage = sequence.iter().map(|x| (*x, 1.)).collect();
|
||||||
|
let mut avg_left: WeightedAverage = left.iter().map(|x| (*x, 1.)).collect();
|
||||||
|
let avg_right: WeightedAverage = right.iter().map(|x| (*x, 1.)).collect();
|
||||||
|
avg_left.merge(&avg_right);
|
||||||
|
assert_eq!(avg_total.weight_sum, avg_left.weight_sum);
|
||||||
|
assert_eq!(avg_total.weight_sum_sq, avg_left.weight_sum_sq);
|
||||||
|
assert_eq!(avg_total.avg, avg_left.avg);
|
||||||
|
assert_eq!(avg_total.v, avg_left.v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn merge_weighted() {
|
||||||
|
let sequence: &[(f64, f64)] = &[
|
||||||
|
(1., 0.1), (2., 0.2), (3., 0.3), (4., 0.4), (5., 0.5),
|
||||||
|
(6., 0.6), (7., 0.7), (8., 0.8), (9., 0.)];
|
||||||
|
for mid in 0..sequence.len() {
|
||||||
|
let (left, right) = sequence.split_at(mid);
|
||||||
|
let avg_total: WeightedAverage = sequence.iter().map(|&(x, w)| (x, w)).collect();
|
||||||
|
let mut avg_left: WeightedAverage = left.iter().map(|&(x, w)| (x, w)).collect();
|
||||||
|
let avg_right: WeightedAverage = right.iter().map(|&(x, w)| (x, w)).collect();
|
||||||
|
avg_left.merge(&avg_right);
|
||||||
|
assert_almost_eq!(avg_total.weight_sum, avg_left.weight_sum, 1e-15);
|
||||||
|
assert_eq!(avg_total.weight_sum_sq, avg_left.weight_sum_sq);
|
||||||
|
assert_almost_eq!(avg_total.avg, avg_left.avg, 1e-15);
|
||||||
|
assert_almost_eq!(avg_total.v, avg_left.v, 1e-14);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user