77// except according to those terms.
88
99use std:: ops:: { Add , Div , Mul } ;
10- use libnum:: { self , One , Zero , Float } ;
10+ use libnum:: { self , One , Zero , Float , FromPrimitive } ;
1111use itertools:: free:: enumerate;
1212
1313use imp_prelude:: * ;
@@ -174,8 +174,11 @@ impl<A, S, D> ArrayBase<S, D>
174174 /// n i=1
175175 /// ```
176176 ///
177- /// **Panics** if `ddof` is greater than or equal to the length of the
178- /// axis, if `axis` is out of bounds, or if the length of the axis is zero.
177+ /// and `n` is the length of the axis.
178+ ///
179+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
180+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
181+ /// numbers in the range `0..=n`.
179182 ///
180183 /// # Example
181184 ///
@@ -190,27 +193,28 @@ impl<A, S, D> ArrayBase<S, D>
190193 /// ```
191194 pub fn var_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
192195 where
193- A : Float ,
196+ A : Float + FromPrimitive ,
194197 D : RemoveAxis ,
195198 {
196- let mut count = A :: zero ( ) ;
199+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
200+ let n = A :: from_usize ( self . len_of ( axis) ) . expect ( "Converting length to `A` must not fail." ) ;
201+ assert ! (
202+ !( ddof < zero || ddof > n) ,
203+ "`ddof` must not be less than zero or greater than the length of \
204+ the axis",
205+ ) ;
206+ let dof = n - ddof;
197207 let mut mean = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
198208 let mut sum_sq = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
199- for subview in self . axis_iter ( axis) {
200- count = count + A :: one ( ) ;
209+ for ( i , subview) in self . axis_iter ( axis) . enumerate ( ) {
210+ let count = A :: from_usize ( i + 1 ) . expect ( "Converting index to `A` must not fail." ) ;
201211 azip ! ( mut mean, mut sum_sq, x ( subview) in {
202212 let delta = x - * mean;
203213 * mean = * mean + delta / count;
204214 * sum_sq = ( x - * mean) . mul_add( delta, * sum_sq) ;
205215 } ) ;
206216 }
207- if ddof >= count {
208- panic ! ( "`ddof` needs to be strictly smaller than the length \
209- of the axis you are computing the variance for!")
210- } else {
211- let dof = count - ddof;
212- sum_sq. mapv_into ( |s| s / dof)
213- }
217+ sum_sq. mapv_into ( |s| s / dof)
214218 }
215219
216220 /// Return standard deviation along `axis`.
@@ -238,8 +242,11 @@ impl<A, S, D> ArrayBase<S, D>
238242 /// n i=1
239243 /// ```
240244 ///
241- /// **Panics** if `ddof` is greater than or equal to the length of the
242- /// axis, if `axis` is out of bounds, or if the length of the axis is zero.
245+ /// and `n` is the length of the axis.
246+ ///
247+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
248+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
249+ /// numbers in the range `0..=n`.
243250 ///
244251 /// # Example
245252 ///
@@ -254,7 +261,7 @@ impl<A, S, D> ArrayBase<S, D>
254261 /// ```
255262 pub fn std_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
256263 where
257- A : Float ,
264+ A : Float + FromPrimitive ,
258265 D : RemoveAxis ,
259266 {
260267 self . var_axis ( axis, ddof) . mapv_into ( |x| x. sqrt ( ) )
0 commit comments