11use super :: SummaryStatisticsExt ;
2- use crate :: errors:: EmptyInput ;
3- use ndarray:: { ArrayBase , Data , Dimension } ;
2+ use crate :: errors:: { EmptyInput , MultiInputError , ShapeMismatch } ;
3+ use ndarray:: { Array , ArrayBase , Axis , Data , Dimension , Ix1 , RemoveAxis } ;
44use num_integer:: IterBinomial ;
55use num_traits:: { Float , FromPrimitive , Zero } ;
6- use std:: ops:: { Add , Div } ;
6+ use std:: ops:: { Add , Div , Mul } ;
77
88impl < A , S , D > SummaryStatisticsExt < A , S , D > for ArrayBase < S , D >
99where
2424 }
2525 }
2626
27+ fn weighted_mean ( & self , weights : & Self ) -> Result < A , MultiInputError >
28+ where
29+ A : Copy + Div < Output = A > + Mul < Output = A > + Zero ,
30+ {
31+ return_err_if_empty ! ( self ) ;
32+ let weighted_sum = self . weighted_sum ( weights) ?;
33+ Ok ( weighted_sum / weights. sum ( ) )
34+ }
35+
36+ fn weighted_sum ( & self , weights : & ArrayBase < S , D > ) -> Result < A , MultiInputError >
37+ where
38+ A : Copy + Mul < Output = A > + Zero ,
39+ {
40+ return_err_unless_same_shape ! ( self , weights) ;
41+ Ok ( self
42+ . iter ( )
43+ . zip ( weights)
44+ . fold ( A :: zero ( ) , |acc, ( & d, & w) | acc + d * w) )
45+ }
46+
47+ fn weighted_mean_axis (
48+ & self ,
49+ axis : Axis ,
50+ weights : & ArrayBase < S , Ix1 > ,
51+ ) -> Result < Array < A , D :: Smaller > , MultiInputError >
52+ where
53+ A : Copy + Div < Output = A > + Mul < Output = A > + Zero ,
54+ D : RemoveAxis ,
55+ {
56+ return_err_if_empty ! ( self ) ;
57+ let mut weighted_sum = self . weighted_sum_axis ( axis, weights) ?;
58+ let weights_sum = weights. sum ( ) ;
59+ weighted_sum. mapv_inplace ( |v| v / weights_sum) ;
60+ Ok ( weighted_sum)
61+ }
62+
63+ fn weighted_sum_axis (
64+ & self ,
65+ axis : Axis ,
66+ weights : & ArrayBase < S , Ix1 > ,
67+ ) -> Result < Array < A , D :: Smaller > , MultiInputError >
68+ where
69+ A : Copy + Mul < Output = A > + Zero ,
70+ D : RemoveAxis ,
71+ {
72+ if self . shape ( ) [ axis. index ( ) ] != weights. len ( ) {
73+ return Err ( MultiInputError :: ShapeMismatch ( ShapeMismatch {
74+ first_shape : self . shape ( ) . to_vec ( ) ,
75+ second_shape : weights. shape ( ) . to_vec ( ) ,
76+ } ) ) ;
77+ }
78+
79+ // We could use `lane.weighted_sum` here, but we're avoiding 2
80+ // conditions and an unwrap per lane.
81+ Ok ( self . map_axis ( axis, |lane| {
82+ lane. iter ( )
83+ . zip ( weights)
84+ . fold ( A :: zero ( ) , |acc, ( & d, & w) | acc + d * w)
85+ } ) )
86+ }
87+
2788 fn harmonic_mean ( & self ) -> Result < A , EmptyInput >
2889 where
2990 A : Float + FromPrimitive ,
@@ -194,18 +255,31 @@ where
194255#[ cfg( test) ]
195256mod tests {
196257 use super :: SummaryStatisticsExt ;
197- use crate :: errors:: EmptyInput ;
198- use approx:: assert_abs_diff_eq;
199- use ndarray:: { array, Array , Array1 } ;
258+ use crate :: errors:: { EmptyInput , MultiInputError , ShapeMismatch } ;
259+ use approx:: { abs_diff_eq , assert_abs_diff_eq} ;
260+ use ndarray:: { arr0 , array, Array , Array1 , Array2 , Axis } ;
200261 use ndarray_rand:: RandomExt ;
201262 use noisy_float:: types:: N64 ;
263+ use quickcheck:: { quickcheck, TestResult } ;
202264 use rand:: distributions:: Uniform ;
203265 use std:: f64;
204266
#[test]
fn test_means_with_nan_values() {
    // A NaN in either the data or the weights must surface as a NaN result.
    let a = array![f64::NAN, 1.];
    let nan_weights = array![1.0, f64::NAN];

    assert!(a.mean().unwrap().is_nan());

    let weighted_mean = a.weighted_mean(&nan_weights).unwrap();
    assert!(weighted_mean.is_nan());
    let weighted_sum = a.weighted_sum(&nan_weights).unwrap();
    assert!(weighted_sum.is_nan());

    let weighted_mean_axis = a.weighted_mean_axis(Axis(0), &nan_weights).unwrap();
    assert!(weighted_mean_axis.into_scalar().is_nan());
    let weighted_sum_axis = a.weighted_sum_axis(Axis(0), &nan_weights).unwrap();
    assert!(weighted_sum_axis.into_scalar().is_nan());

    assert!(a.harmonic_mean().unwrap().is_nan());
    assert!(a.geometric_mean().unwrap().is_nan());
}
@@ -214,16 +288,40 @@ mod tests {
214288 fn test_means_with_empty_array_of_floats ( ) {
215289 let a: Array1 < f64 > = array ! [ ] ;
216290 assert_eq ! ( a. mean( ) , None ) ;
291+ assert_eq ! (
292+ a. weighted_mean( & array![ 1.0 ] ) ,
293+ Err ( MultiInputError :: EmptyInput )
294+ ) ;
295+ assert_eq ! (
296+ a. weighted_mean_axis( Axis ( 0 ) , & array![ 1.0 ] ) ,
297+ Err ( MultiInputError :: EmptyInput )
298+ ) ;
217299 assert_eq ! ( a. harmonic_mean( ) , Err ( EmptyInput ) ) ;
218300 assert_eq ! ( a. geometric_mean( ) , Err ( EmptyInput ) ) ;
301+
302+ // The sum methods accept empty arrays
303+ assert_eq ! ( a. weighted_sum( & array![ ] ) , Ok ( 0.0 ) ) ;
304+ assert_eq ! ( a. weighted_sum_axis( Axis ( 0 ) , & array![ ] ) , Ok ( arr0( 0.0 ) ) ) ;
219305 }
220306
#[test]
fn test_means_with_empty_array_of_noisy_floats() {
    let a: Array1<N64> = array![];
    let zero = N64::new(0.0);

    // Every mean rejects an empty array.
    assert!(a.mean().is_none());
    assert_eq!(
        a.weighted_mean(&array![]).unwrap_err(),
        MultiInputError::EmptyInput
    );
    assert_eq!(
        a.weighted_mean_axis(Axis(0), &array![]).unwrap_err(),
        MultiInputError::EmptyInput
    );
    assert_eq!(a.harmonic_mean().unwrap_err(), EmptyInput);
    assert_eq!(a.geometric_mean().unwrap_err(), EmptyInput);

    // The sum methods accept empty arrays and produce zero.
    assert_eq!(a.weighted_sum(&array![]).unwrap(), zero);
    assert_eq!(
        a.weighted_sum_axis(Axis(0), &array![]).unwrap(),
        arr0(zero)
    );
}
228326
229327 #[ test]
@@ -240,9 +338,9 @@ mod tests {
240338 ] ;
241339 // Computed using NumPy
242340 let expected_mean = 0.5475494059146699 ;
341+ let expected_weighted_mean = 0.6782420496397121 ;
243342 // Computed using SciPy
244343 let expected_harmonic_mean = 0.21790094950226022 ;
245- // Computed using SciPy
246344 let expected_geometric_mean = 0.4345897639796527 ;
247345
248346 assert_abs_diff_eq ! ( a. mean( ) . unwrap( ) , expected_mean, epsilon = 1e-9 ) ;
@@ -256,6 +354,114 @@ mod tests {
256354 expected_geometric_mean,
257355 epsilon = 1e-12
258356 ) ;
357+
358+ // weighted_mean with itself, normalized
359+ let weights = & a / a. sum ( ) ;
360+ assert_abs_diff_eq ! (
361+ a. weighted_sum( & weights) . unwrap( ) ,
362+ expected_weighted_mean,
363+ epsilon = 1e-12
364+ ) ;
365+
366+ let data = a. into_shape ( ( 2 , 5 , 5 ) ) . unwrap ( ) ;
367+ let weights = array ! [ 0.1 , 0.5 , 0.25 , 0.15 , 0.2 ] ;
368+ assert_abs_diff_eq ! (
369+ data. weighted_mean_axis( Axis ( 1 ) , & weights) . unwrap( ) ,
370+ array![
371+ [ 0.50202721 , 0.53347361 , 0.29086033 , 0.56995637 , 0.37087139 ] ,
372+ [ 0.58028328 , 0.50485216 , 0.59349973 , 0.70308937 , 0.72280630 ]
373+ ] ,
374+ epsilon = 1e-8
375+ ) ;
376+ assert_abs_diff_eq ! (
377+ data. weighted_mean_axis( Axis ( 2 ) , & weights) . unwrap( ) ,
378+ array![
379+ [ 0.33434378 , 0.38365259 , 0.56405781 , 0.48676574 , 0.55016179 ] ,
380+ [ 0.71112376 , 0.55134174 , 0.45566513 , 0.74228516 , 0.68405851 ]
381+ ] ,
382+ epsilon = 1e-8
383+ ) ;
384+ assert_abs_diff_eq ! (
385+ data. weighted_sum_axis( Axis ( 1 ) , & weights) . unwrap( ) ,
386+ array![
387+ [ 0.60243266 , 0.64016833 , 0.34903240 , 0.68394765 , 0.44504567 ] ,
388+ [ 0.69633993 , 0.60582259 , 0.71219968 , 0.84370724 , 0.86736757 ]
389+ ] ,
390+ epsilon = 1e-8
391+ ) ;
392+ assert_abs_diff_eq ! (
393+ data. weighted_sum_axis( Axis ( 2 ) , & weights) . unwrap( ) ,
394+ array![
395+ [ 0.40121254 , 0.46038311 , 0.67686937 , 0.58411889 , 0.66019415 ] ,
396+ [ 0.85334851 , 0.66161009 , 0.54679815 , 0.89074219 , 0.82087021 ]
397+ ] ,
398+ epsilon = 1e-8
399+ ) ;
400+ }
401+
#[test]
fn weighted_sum_dimension_zero() {
    // A 0 x 20 array: one axis has zero length, the other does not.
    let a = Array2::<usize>::zeros((0, 20));

    // Summing over the empty axis leaves 20 zero-valued entries.
    let over_rows = a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap();
    assert_eq!(over_rows, Array1::from_elem(20, 0));

    // Summing over the length-20 axis leaves an empty result.
    let over_columns = a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap();
    assert_eq!(over_columns, Array1::from_elem(0, 0));

    // Weights whose length differs from the axis length are rejected.
    let axis_mismatch = MultiInputError::ShapeMismatch(ShapeMismatch {
        first_shape: vec![0, 20],
        second_shape: vec![1],
    });
    assert_eq!(
        a.weighted_sum_axis(Axis(0), &Array1::zeros(1)),
        Err(axis_mismatch)
    );

    // Full-array weights must have exactly the same shape as the data.
    let shape_mismatch = MultiInputError::ShapeMismatch(ShapeMismatch {
        first_shape: vec![0, 20],
        second_shape: vec![10, 20],
    });
    assert_eq!(
        a.weighted_sum(&Array2::zeros((10, 20))),
        Err(shape_mismatch)
    );
}
428+
#[test]
fn mean_eq_if_uniform_weights() {
    // Property: with uniform weights `1 / n`, the plain mean, the weighted
    // mean and the weighted sum all agree (up to an absolute tolerance).
    fn prop(a: Vec<f64>) -> TestResult {
        if a.is_empty() {
            return TestResult::discard();
        }
        let n = a.len();
        let a = Array1::from(a);
        let weights = Array1::from_elem(n, 1.0 / n as f64);

        let m = a.mean().unwrap();
        let wm = a.weighted_mean(&weights).unwrap();
        let ws = a.weighted_sum(&weights).unwrap();

        let mean_matches = abs_diff_eq!(m, wm, epsilon = 1e-9);
        let sum_matches = abs_diff_eq!(wm, ws, epsilon = 1e-9);
        TestResult::from_bool(mean_matches && sum_matches)
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
446+
#[test]
fn mean_axis_eq_if_uniform_weights() {
    // Property: with uniform weights `1 / depth` along axis 0, `mean_axis`,
    // `weighted_mean_axis` and `weighted_sum_axis` all agree (up to an
    // absolute tolerance).
    fn prop(mut a: Vec<f64>) -> TestResult {
        // At least two full 3 x 4 slices are needed.
        if a.len() < 24 {
            return TestResult::discard();
        }
        // Drop the tail so the data reshapes exactly into (depth, 3, 4).
        let depth = a.len() / 12;
        a.truncate(depth * 3 * 4);
        let weights = Array1::from_elem(depth, 1.0 / depth as f64);
        let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
        let ma = a.mean_axis(Axis(0)).unwrap();
        let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap();
        let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap();
        // Both comparisons use the same tight tolerance. The second one
        // previously read `1e12` (missing minus sign), which accepted any
        // finite values and so verified nothing.
        TestResult::from_bool(
            abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e-12),
        )
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
260466
261467 #[ test]
0 commit comments