77// except according to those terms.
88
99use std:: ops:: { Add , Div , Mul } ;
10- use libnum:: { self , One , Zero , Float } ;
10+ use libnum:: { self , One , Zero , Float , FromPrimitive } ;
1111use itertools:: free:: enumerate;
1212
1313use imp_prelude:: * ;
@@ -163,8 +163,11 @@ impl<A, S, D> ArrayBase<S, D>
163163 /// n i=1
164164 /// ```
165165 ///
166- /// **Panics** if `ddof` is less than zero or greater than the length of
167- /// the axis or if `axis` is out of bounds.
166+ /// and `n` is the length of the axis.
167+ ///
168+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
169+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
170+ /// numbers in the range `0..=n`.
168171 ///
169172 /// # Example
170173 ///
@@ -179,26 +182,27 @@ impl<A, S, D> ArrayBase<S, D>
179182 /// ```
180183 pub fn var_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
181184 where
182- A : Float ,
185+ A : Float + FromPrimitive ,
183186 D : RemoveAxis ,
184187 {
185- let mut count = A :: zero ( ) ;
188+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
189+ let n = A :: from_usize ( self . len_of ( axis) ) . expect ( "Converting length to `A` must not fail." ) ;
190+ assert ! (
191+ !( ddof < zero || ddof > n) ,
192+ "`ddof` must not be less than zero or greater than the length of \
193+ the axis",
194+ ) ;
195+ let dof = n - ddof;
186196 let mut mean = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
187197 let mut sum_sq = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
188- for subview in self . axis_iter ( axis) {
189- count = count + A :: one ( ) ;
198+ for ( i , subview) in self . axis_iter ( axis) . enumerate ( ) {
199+ let count = A :: from_usize ( i + 1 ) . expect ( "Converting index to `A` must not fail." ) ;
190200 azip ! ( mut mean, mut sum_sq, x ( subview) in {
191201 let delta = x - * mean;
192202 * mean = * mean + delta / count;
193203 * sum_sq = ( x - * mean) . mul_add( delta, * sum_sq) ;
194204 } ) ;
195205 }
196- assert ! (
197- !( ddof < A :: zero( ) || ddof > count) ,
198- "`ddof` must not be less than zero or greater than the length of \
199- the axis",
200- ) ;
201- let dof = count - ddof;
202206 sum_sq. mapv_into ( |s| s / dof)
203207 }
204208
@@ -227,8 +231,11 @@ impl<A, S, D> ArrayBase<S, D>
227231 /// n i=1
228232 /// ```
229233 ///
230- /// **Panics** if `ddof` is less than zero or greater than the length of
231- /// the axis or if `axis` is out of bounds.
234+ /// and `n` is the length of the axis.
235+ ///
236+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
237+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
238+ /// numbers in the range `0..=n`.
232239 ///
233240 /// # Example
234241 ///
@@ -243,7 +250,7 @@ impl<A, S, D> ArrayBase<S, D>
243250 /// ```
244251 pub fn std_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
245252 where
246- A : Float ,
253+ A : Float + FromPrimitive ,
247254 D : RemoveAxis ,
248255 {
249256 self . var_axis ( axis, ddof) . mapv_into ( |x| x. sqrt ( ) )
0 commit comments