Skip to content

Commit ed24fad

Browse files
committed
Change mean_axis to use FromPrimitive
The documentation for the `Zero` and `One` traits says only that they are the additive and multiplicative identities; it doesn't say anything about converting an integer to a float by adding `One::one()` to `Zero::zero()` repeatedly. Additionally, it's nice to convert the length to `A` directly instead of having to use a loop.
1 parent f7fb81f commit ed24fad

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

src/numeric/impl_numeric.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// except according to those terms.
88

99
use std::ops::{Add, Div, Mul};
10-
use libnum::{self, One, Zero, Float};
10+
use libnum::{self, Zero, Float, FromPrimitive};
1111
use itertools::free::enumerate;
1212

1313
use imp_prelude::*;
@@ -112,8 +112,9 @@ impl<A, S, D> ArrayBase<S, D>
112112

113113
/// Return mean along `axis`.
114114
///
115-
/// **Panics** if `axis` is out of bounds or if the length of the axis is
116-
/// zero and division by zero panics for type `A`.
115+
/// **Panics** if `axis` is out of bounds, if the length of the axis is
116+
/// zero and division by zero panics for type `A`, or if `A::from_usize()`
117+
/// fails for the axis length.
117118
///
118119
/// ```
119120
/// use ndarray::{aview1, arr2, Axis};
@@ -126,16 +127,12 @@ impl<A, S, D> ArrayBase<S, D>
126127
/// );
127128
/// ```
128129
pub fn mean_axis(&self, axis: Axis) -> Array<A, D::Smaller>
129-
where A: Clone + Zero + One + Add<Output=A> + Div<Output=A>,
130+
where A: Clone + Zero + FromPrimitive + Add<Output=A> + Div<Output=A>,
130131
D: RemoveAxis,
131132
{
132-
let n = self.len_of(axis);
133+
let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail.");
133134
let sum = self.sum_axis(axis);
134-
let mut cnt = A::zero();
135-
for _ in 0..n {
136-
cnt = cnt + A::one();
137-
}
138-
sum / &aview0(&cnt)
135+
sum / &aview0(&n)
139136
}
140137

141138
/// Return variance along `axis`.

0 commit comments

Comments
 (0)