1+ use crate :: errors:: EmptyInput ;
12use ndarray:: prelude:: * ;
23use ndarray:: Data ;
34use num_traits:: { Float , FromPrimitive } ;
@@ -41,10 +42,10 @@ where
4142 /// ```
4243 /// and similarly for ̅y.
4344 ///
44- /// **Panics** if `ddof ` is greater than or equal to the number of
45- /// observations, if the number of observations is zero and division by
46- /// zero panics for type `A`, or if the type cast of `n_observations` from
47- /// `usize` to `A` fails.
45+ /// If `M ` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
46+ ///
47+ /// **Panics** if `ddof` is negative or greater than or equal to the number of
48+ /// observations, or if the type cast of `n_observations` from `usize` to `A` fails.
4849 ///
4950 /// # Example
5051 ///
@@ -54,13 +55,13 @@ where
5455 ///
5556 /// let a = arr2(&[[1., 3., 5.],
5657 /// [2., 4., 6.]]);
57- /// let covariance = a.cov(1.);
58+ /// let covariance = a.cov(1.).unwrap() ;
5859 /// assert_eq!(
5960 /// covariance,
6061 /// aview2(&[[4., 4.], [4., 4.]])
6162 /// );
6263 /// ```
63- fn cov ( & self , ddof : A ) -> Array2 < A >
64+ fn cov ( & self , ddof : A ) -> Result < Array2 < A > , EmptyInput >
6465 where
6566 A : Float + FromPrimitive ;
6667
@@ -89,30 +90,35 @@ where
8990 /// R_ij = rho(X_i, X_j)
9091 /// ```
9192 ///
92- /// **Panics** if `M` is empty, if the type cast of `n_observations`
93- /// from `usize` to `A` fails or if the standard deviation of one of the random
93+ /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
94+ ///
95+ /// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or
96+ /// if the standard deviation of one of the random variables is zero and
97+ /// division by zero panics for type A.
9498 ///
9599 /// # Example
96100 ///
97- /// variables is zero and division by zero panics for type A.
98101 /// ```
102+ /// use approx;
99103 /// use ndarray::arr2;
100104 /// use ndarray_stats::CorrelationExt;
105+ /// use approx::AbsDiffEq;
101106 ///
102107 /// let a = arr2(&[[1., 3., 5.],
103108 /// [2., 4., 6.]]);
104- /// let corr = a.pearson_correlation();
109+ /// let corr = a.pearson_correlation().unwrap();
110+ /// let epsilon = 1e-7;
105111 /// assert!(
106- /// corr.all_close (
112+ /// corr.abs_diff_eq (
107113 /// &arr2(&[
108114 /// [1., 1.],
109115 /// [1., 1.],
110116 /// ]),
111- /// 1e-7
117+ /// epsilon
112118 /// )
113119 /// );
114120 /// ```
115- fn pearson_correlation ( & self ) -> Array2 < A >
121+ fn pearson_correlation ( & self ) -> Result < Array2 < A > , EmptyInput >
116122 where
117123 A : Float + FromPrimitive ;
118124
@@ -123,7 +129,7 @@ impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
123129where
124130 S : Data < Elem = A > ,
125131{
126- fn cov ( & self , ddof : A ) -> Array2 < A >
132+ fn cov ( & self , ddof : A ) -> Result < Array2 < A > , EmptyInput >
127133 where
128134 A : Float + FromPrimitive ,
129135 {
@@ -139,28 +145,37 @@ where
139145 n_observations - ddof
140146 } ;
141147 let mean = self . mean_axis ( observation_axis) ;
142- let denoised = self - & mean. insert_axis ( observation_axis) ;
143- let covariance = denoised. dot ( & denoised. t ( ) ) ;
144- covariance. mapv_into ( |x| x / dof)
148+ match mean {
149+ Some ( mean) => {
150+ let denoised = self - & mean. insert_axis ( observation_axis) ;
151+ let covariance = denoised. dot ( & denoised. t ( ) ) ;
152+ Ok ( covariance. mapv_into ( |x| x / dof) )
153+ }
154+ None => Err ( EmptyInput ) ,
155+ }
145156 }
146157
147- fn pearson_correlation ( & self ) -> Array2 < A >
158+ fn pearson_correlation ( & self ) -> Result < Array2 < A > , EmptyInput >
148159 where
149160 A : Float + FromPrimitive ,
150161 {
151- let observation_axis = Axis ( 1 ) ;
152- // The ddof value doesn't matter, as long as we use the same one
153- // for computing covariance and standard deviation
154- // We choose -1 to avoid panicking when we only have one
155- // observation per random variable (or no observations at all)
156- let ddof = -A :: one ( ) ;
157- let cov = self . cov ( ddof) ;
158- let std = self
159- . std_axis ( observation_axis, ddof)
160- . insert_axis ( observation_axis) ;
161- let std_matrix = std. dot ( & std. t ( ) ) ;
162- // element-wise division
163- cov / std_matrix
162+ match self . dim ( ) {
163+ ( n, m) if n > 0 && m > 0 => {
164+ let observation_axis = Axis ( 1 ) ;
165+ // The ddof value doesn't matter, as long as we use the same one
166+ // for computing covariance and standard deviation
167+ // We choose 0 as it is the smallest number admitted by std_axis
168+ let ddof = A :: zero ( ) ;
169+ let cov = self . cov ( ddof) . unwrap ( ) ;
170+ let std = self
171+ . std_axis ( observation_axis, ddof)
172+ . insert_axis ( observation_axis) ;
173+ let std_matrix = std. dot ( & std. t ( ) ) ;
174+ // element-wise division
175+ Ok ( cov / std_matrix)
176+ }
177+ _ => Err ( EmptyInput ) ,
178+ }
164179 }
165180
166181 private_impl ! { }
@@ -180,9 +195,10 @@ mod cov_tests {
180195 let n_random_variables = 3 ;
181196 let n_observations = 4 ;
182197 let a = Array :: from_elem ( ( n_random_variables, n_observations) , value) ;
183- a. cov ( 1. ) . all_close (
198+ abs_diff_eq ! (
199+ a. cov( 1. ) . unwrap( ) ,
184200 & Array :: zeros( ( n_random_variables, n_random_variables) ) ,
185- 1e-8 ,
201+ epsilon = 1e-8 ,
186202 )
187203 }
188204
@@ -194,8 +210,8 @@ mod cov_tests {
194210 ( n_random_variables, n_observations) ,
195211 Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ,
196212 ) ;
197- let covariance = a. cov ( 1. ) ;
198- covariance . all_close ( & covariance. t ( ) , 1e-8 )
213+ let covariance = a. cov ( 1. ) . unwrap ( ) ;
214+ abs_diff_eq ! ( covariance , & covariance. t( ) , epsilon = 1e-8 )
199215 }
200216
201217 #[ test]
@@ -205,31 +221,31 @@ mod cov_tests {
205221 let n_observations = 4 ;
206222 let a = Array :: random ( ( n_random_variables, n_observations) , Uniform :: new ( 0. , 10. ) ) ;
207223 let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
208- a. cov ( invalid_ddof) ;
224+ let _ = a. cov ( invalid_ddof) ;
209225 }
210226
211227 #[ test]
212228 fn test_covariance_zero_variables ( ) {
213229 let a = Array2 :: < f32 > :: zeros ( ( 0 , 2 ) ) ;
214230 let cov = a. cov ( 1. ) ;
215- assert_eq ! ( cov. shape( ) , & [ 0 , 0 ] ) ;
231+ assert ! ( cov. is_ok( ) ) ;
232+ assert_eq ! ( cov. unwrap( ) . shape( ) , & [ 0 , 0 ] ) ;
216233 }
217234
218235 #[ test]
219236 fn test_covariance_zero_observations ( ) {
220237 let a = Array2 :: < f32 > :: zeros ( ( 2 , 0 ) ) ;
221238 // Negative ddof (-1 < 0) to avoid invalid-ddof panic
222239 let cov = a. cov ( -1. ) ;
223- assert_eq ! ( cov. shape( ) , & [ 2 , 2 ] ) ;
224- cov. mapv ( |x| assert_eq ! ( x, 0. ) ) ;
240+ assert_eq ! ( cov, Err ( EmptyInput ) ) ;
225241 }
226242
227243 #[ test]
228244 fn test_covariance_zero_variables_zero_observations ( ) {
229245 let a = Array2 :: < f32 > :: zeros ( ( 0 , 0 ) ) ;
230246 // Negative ddof (-1 < 0) to avoid invalid-ddof panic
231247 let cov = a. cov ( -1. ) ;
232- assert_eq ! ( cov. shape ( ) , & [ 0 , 0 ] ) ;
248+ assert_eq ! ( cov, Err ( EmptyInput ) ) ;
233249 }
234250
235251 #[ test]
@@ -255,7 +271,7 @@ mod cov_tests {
255271 ]
256272 ] ;
257273 assert_eq ! ( a. ndim( ) , 2 ) ;
258- assert ! ( a. cov( 1. ) . all_close ( & numpy_covariance, 1e-8 ) ) ;
274+ assert_abs_diff_eq ! ( a. cov( 1. ) . unwrap ( ) , & numpy_covariance, epsilon = 1e-8 ) ;
259275 }
260276
261277 #[ test]
@@ -264,7 +280,7 @@ mod cov_tests {
264280 fn test_covariance_for_badly_conditioned_array ( ) {
265281 let a: Array2 < f64 > = array ! [ [ 1e12 + 1. , 1e12 - 1. ] , [ 1e-6 + 1e-12 , 1e-6 - 1e-12 ] , ] ;
266282 let expected_covariance = array ! [ [ 2. , 2e-12 ] , [ 2e-12 , 2e-24 ] ] ;
267- assert ! ( a. cov( 1. ) . all_close ( & expected_covariance, 1e-24 ) ) ;
283+ assert_abs_diff_eq ! ( a. cov( 1. ) . unwrap ( ) , & expected_covariance, epsilon = 1e-24 ) ;
268284 }
269285}
270286
@@ -284,8 +300,12 @@ mod pearson_correlation_tests {
284300 ( n_random_variables, n_observations) ,
285301 Uniform :: new ( -bound. abs ( ) , bound. abs ( ) ) ,
286302 ) ;
287- let pearson_correlation = a. pearson_correlation ( ) ;
288- pearson_correlation. all_close ( & pearson_correlation. t ( ) , 1e-8 )
303+ let pearson_correlation = a. pearson_correlation ( ) . unwrap ( ) ;
304+ abs_diff_eq ! (
305+ pearson_correlation. view( ) ,
306+ pearson_correlation. t( ) ,
307+ epsilon = 1e-8
308+ )
289309 }
290310
291311 #[ quickcheck]
@@ -295,6 +315,7 @@ mod pearson_correlation_tests {
295315 let a = Array :: from_elem ( ( n_random_variables, n_observations) , value) ;
296316 let pearson_correlation = a. pearson_correlation ( ) ;
297317 pearson_correlation
318+ . unwrap ( )
298319 . iter ( )
299320 . map ( |x| x. is_nan ( ) )
300321 . fold ( true , |acc, flag| acc & flag)
@@ -304,21 +325,21 @@ mod pearson_correlation_tests {
304325 fn test_zero_variables ( ) {
305326 let a = Array2 :: < f32 > :: zeros ( ( 0 , 2 ) ) ;
306327 let pearson_correlation = a. pearson_correlation ( ) ;
307- assert_eq ! ( pearson_correlation. shape ( ) , & [ 0 , 0 ] ) ;
328+ assert_eq ! ( pearson_correlation, Err ( EmptyInput ) )
308329 }
309330
310331 #[ test]
311332 fn test_zero_observations ( ) {
312333 let a = Array2 :: < f32 > :: zeros ( ( 2 , 0 ) ) ;
313334 let pearson = a. pearson_correlation ( ) ;
314- pearson . mapv ( |x| x . is_nan ( ) ) ;
335+ assert_eq ! ( pearson , Err ( EmptyInput ) ) ;
315336 }
316337
317338 #[ test]
318339 fn test_zero_variables_zero_observations ( ) {
319340 let a = Array2 :: < f32 > :: zeros ( ( 0 , 0 ) ) ;
320341 let pearson = a. pearson_correlation ( ) ;
321- assert_eq ! ( pearson. shape ( ) , & [ 0 , 0 ] ) ;
342+ assert_eq ! ( pearson, Err ( EmptyInput ) ) ;
322343 }
323344
324345 #[ test]
@@ -338,6 +359,10 @@ mod pearson_correlation_tests {
338359 [ 0.1365648 , 0.38954398 , -0.17324776 , -0.8743213 , 1. ]
339360 ] ;
340361 assert_eq ! ( a. ndim( ) , 2 ) ;
341- assert ! ( a. pearson_correlation( ) . all_close( & numpy_corrcoeff, 1e-7 ) ) ;
362+ assert_abs_diff_eq ! (
363+ a. pearson_correlation( ) . unwrap( ) ,
364+ numpy_corrcoeff,
365+ epsilon = 1e-7
366+ ) ;
342367 }
343368}
0 commit comments