diff --git a/src/weighted_average.rs b/src/weighted_average.rs index 5c6fe08..667f749 100644 --- a/src/weighted_average.rs +++ b/src/weighted_average.rs @@ -6,8 +6,6 @@ use core; pub struct WeightedAverage { /// Sum of the weights. weight_sum: f64, - /// Sum of the squares of the weights. - weight_sum_sq: f64, /// Average value. avg: f64, /// Intermediate sum of squares for calculating the variance. @@ -17,7 +15,7 @@ pub struct WeightedAverage { impl WeightedAverage { /// Create a new weighted average. pub fn new() -> WeightedAverage { - WeightedAverage { weight_sum: 0., weight_sum_sq: 0., avg: 0., v: 0. } + WeightedAverage { weight_sum: 0., avg: 0., v: 0. } } /// Add a sample to the weighted sequence of which the average is calculated. @@ -29,7 +27,6 @@ impl WeightedAverage { // and // http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf. self.weight_sum += weight; - self.weight_sum_sq += weight*weight; let prev_avg = self.avg; self.avg = prev_avg + (weight / self.weight_sum) * (sample - prev_avg); self.v += weight * (sample - prev_avg) * (sample - self.avg); @@ -37,7 +34,7 @@ impl WeightedAverage { /// Determine whether the sequence is empty. pub fn is_empty(&self) -> bool { - self.weight_sum_sq == 0. + self.weight_sum == 0. && self.v == 0. && self.avg == 0. } /// Return the sum of the weights. @@ -45,24 +42,11 @@ impl WeightedAverage { self.weight_sum } - /// Return the sum of the squared weights. - pub fn sum_weights_sq(&self) -> f64 { - self.weight_sum_sq - } - /// Estimate the weighted mean of the sequence. pub fn mean(&self) -> f64 { self.avg } - /// Calculate the effective sample size. - pub fn effective_len(&self) -> f64 { - if self.is_empty() { - return 0. - } - self.weight_sum * self.weight_sum / self.weight_sum_sq - } - /// Calculate the population variance of the weighted sequence. /// /// This assumes that the sequence consists of the entire population and the @@ -80,41 +64,30 @@ impl WeightedAverage { /// This assumes that the sequence consists of samples of a larger /// population and the weights represent *frequency*. /// - /// Note that this is undefined if the sum of the weights is 1. + /// Note that this will return 0 if the sum of the weights is <= 1. pub fn sample_variance(&self) -> f64 { - if self.effective_len() <= 1. { + if self.weight_sum <= 1. { 0. } else { self.v / (self.weight_sum - 1.0) } } - /// Calculate the reliability variance of the weighted sequence. - /// - /// This assumes weights represent *reliability*. - pub fn reliability_variance(&self) -> f64 { - if self.is_empty() { - 0. - } else { - self.v / (self.weight_sum - self.weight_sum_sq / self.weight_sum) - } - } - /// Estimate the standard error of the weighted mean of the sequence. + /// + /// Note that this will return 0 if the sum of the weights is 0. + /// For this estimator the sum of weights should be larger than 1. pub fn error(&self) -> f64 { // This uses the same estimate as SPSS. // // See http://www.analyticalgroup.com/download/WEIGHTED_MEAN.pdf. - if self.is_empty() { + if self.weight_sum == 0. { return 0.; } - let variance = if self.weight_sum != 1. { - // We generally want to use the weighted sample variance... - self.sample_variance() - } else { - // ...but in this case it is undefined, so we use the weighted - // population variance instead. + let variance = if self.weight_sum <= 1. { self.population_variance() + } else { + self.sample_variance() }; (variance / self.weight_sum).sqrt() } @@ -146,7 +119,6 @@ impl WeightedAverage { self.v += other.v + delta*delta * self.weight_sum * other.weight_sum / total_weight_sum; self.weight_sum = total_weight_sum; - self.weight_sum_sq += other.weight_sum_sq; } } @@ -178,17 +150,14 @@ mod tests { fn trivial() { let mut a = WeightedAverage::new(); assert_eq!(a.sum_weights(), 0.); - assert_eq!(a.sum_weights_sq(), 0.); a.add(1.0, 1.0); assert_eq!(a.mean(), 1.0); assert_eq!(a.sum_weights(), 1.0); - assert_eq!(a.sum_weights_sq(), 1.0); assert_eq!(a.population_variance(), 0.0); assert_eq!(a.error(), 0.0); a.add(1.0, 1.0); assert_eq!(a.mean(), 1.0); assert_eq!(a.sum_weights(), 2.0); - assert_eq!(a.sum_weights_sq(), 2.0); assert_eq!(a.population_variance(), 0.0); assert_eq!(a.error(), 0.0); } @@ -212,7 +181,6 @@ mod tests { assert_almost_eq!(a.mean(), 3.53486, 1e-5); assert_almost_eq!(a.sample_variance(), 1.8210, 1e-4); assert_eq!(a.sum_weights(), 10.47); - assert_almost_eq!(a.effective_len(), 8.2315, 1e-4); assert_almost_eq!(a.error(), f64::sqrt(0.1739), 1e-4); } @@ -235,7 +203,6 @@ mod tests { let avg_right: WeightedAverage = right.iter().map(|x| (*x, 1.)).collect(); avg_left.merge(&avg_right); assert_eq!(avg_total.weight_sum, avg_left.weight_sum); - assert_eq!(avg_total.weight_sum_sq, avg_left.weight_sum_sq); assert_eq!(avg_total.avg, avg_left.avg); assert_eq!(avg_total.v, avg_left.v); } @@ -253,7 +220,6 @@ mod tests { let avg_right: WeightedAverage = right.iter().map(|&(x, w)| (x, w)).collect(); avg_left.merge(&avg_right); assert_almost_eq!(avg_total.weight_sum, avg_left.weight_sum, 1e-15); - assert_eq!(avg_total.weight_sum_sq, avg_left.weight_sum_sq); assert_almost_eq!(avg_total.avg, avg_left.avg, 1e-15); assert_almost_eq!(avg_total.v, avg_left.v, 1e-14); }