1- //! Element-wise methods for ndarray
1+ // Element-wise methods for ndarray
22
3+ #[ cfg( feature = "std" ) ]
34use num_traits:: Float ;
45
56use crate :: imp_prelude:: * ;
67
7- macro_rules! boolean_op {
8- ( $( $( #[ $meta1: meta] ) * fn $id1: ident $( #[ $meta2: meta] ) * fn $id2: ident -> $func: ident) +) => {
9- $( $( #[ $meta1] ) *
8+ #[ cfg( feature = "std" ) ]
9+ macro_rules! boolean_ops {
10+ ( $( #[ $meta1: meta] ) * fn $func: ident
11+ $( #[ $meta2: meta] ) * fn $all: ident
12+ $( #[ $meta3: meta] ) * fn $any: ident) => {
13+ $( #[ $meta1] ) *
1014 #[ must_use = "method returns a new array and does not mutate the original value" ]
11- pub fn $id1 ( & self ) -> Array <bool , D > {
15+ pub fn $func ( & self ) -> Array <bool , D > {
1216 self . mapv( A :: $func)
1317 }
1418 $( #[ $meta2] ) *
1519 #[ must_use = "method returns a new boolean value and does not mutate the original value" ]
16- pub fn $id2( & self ) -> bool {
17- self . mapv( A :: $func) . iter( ) . any( |& b| b)
18- } ) +
20+ pub fn $all( & self ) -> bool {
21+ $crate:: Zip :: from( self ) . all( |& elt| !elt. $func( ) )
22+ }
23+ $( #[ $meta3] ) *
24+ #[ must_use = "method returns a new boolean value and does not mutate the original value" ]
25+ pub fn $any( & self ) -> bool {
26+ !self . $all( )
27+ }
1928 } ;
2029}
2130
22- macro_rules! unary_op {
31+ #[ cfg( feature = "std" ) ]
32+ macro_rules! unary_ops {
2333 ( $( $( #[ $meta: meta] ) * fn $id: ident) +) => {
2434 $( $( #[ $meta] ) *
2535 #[ must_use = "method returns a new array and does not mutate the original value" ]
@@ -29,7 +39,8 @@ macro_rules! unary_op {
2939 } ;
3040}
3141
32- macro_rules! binary_op {
42+ #[ cfg( feature = "std" ) ]
43+ macro_rules! binary_ops {
3344 ( $( $( #[ $meta: meta] ) * fn $id: ident( $ty: ty) ) +) => {
3445 $( $( #[ $meta] ) *
3546 #[ must_use = "method returns a new array and does not mutate the original value" ]
@@ -39,103 +50,87 @@ macro_rules! binary_op {
3950 } ;
4051}
4152
42- /// # Element-wise methods for Float Array
53+ /// # Element-wise methods for float arrays
4354///
4455/// Element-wise math functions for any array type that contains float number.
56+ #[ cfg( feature = "std" ) ]
4557impl < A , S , D > ArrayBase < S , D >
4658where
47- A : Float ,
59+ A : ' static + Float ,
4860 S : Data < Elem = A > ,
4961 D : Dimension ,
5062{
51- boolean_op ! {
63+ boolean_ops ! {
5264 /// If the number is `NaN` (not a number), then `true` is returned for each element.
5365 fn is_nan
66+ /// Return `true` if all elements are `NaN` (not a number).
67+ fn is_all_nan
5468 /// Return `true` if any element is `NaN` (not a number).
55- fn is_any_nan -> is_nan
56-
69+ fn is_any_nan
70+ }
71+ boolean_ops ! {
5772 /// If the number is infinity, then `true` is returned for each element.
5873 fn is_infinite
74+ /// Return `true` if all elements are infinity.
75+ fn is_all_infinite
5976 /// Return `true` if any element is infinity.
60- fn is_any_infinite -> is_infinite
77+ fn is_any_infinite
6178 }
62- unary_op ! {
79+ unary_ops ! {
6380 /// The largest integer less than or equal to each element.
6481 fn floor
65-
6682 /// The smallest integer less than or equal to each element.
6783 fn ceil
68-
6984 /// The nearest integer of each element.
7085 fn round
71-
7286 /// The integer part of each element.
7387 fn trunc
74-
7588 /// The fractional part of each element.
7689 fn fract
77-
7890 /// Absolute of each element.
7991 fn abs
80-
8192 /// Sign number of each element.
8293 ///
8394 /// + `1.0` for all positive numbers.
8495 /// + `-1.0` for all negative numbers.
8596 /// + `NaN` for all `NaN` (not a number).
8697 fn signum
87-
8898 /// The reciprocal (inverse) of each element, `1/x`.
8999 fn recip
90-
91100 /// Square root of each element.
92101 fn sqrt
93-
94102 /// `e^x` of each element (exponential function).
95103 fn exp
96-
97104 /// `2^x` of each element.
98105 fn exp2
99-
100106 /// Natural logarithm of each element.
101107 fn ln
102-
103108 /// Base 2 logarithm of each element.
104109 fn log2
105-
106110 /// Base 10 logarithm of each element.
107111 fn log10
108-
109112 /// Cubic root of each element.
110113 fn cbrt
111-
112114 /// Sine of each element (in radians).
113115 fn sin
114-
115116 /// Cosine of each element (in radians).
116117 fn cos
117-
118118 /// Tangent of each element (in radians).
119119 fn tan
120-
121120 /// Converts radians to degrees for each element.
122121 fn to_degrees
123-
124122 /// Converts degrees to radians for each element.
125123 fn to_radians
126124 }
127- binary_op ! {
125+ binary_ops ! {
128126 /// Integer power of each element.
129127 ///
130128 /// This function is generally faster than using float power.
131129 fn powi( i32 )
132-
133130 /// Float power of each element.
134131 fn powf( A )
135-
136132 /// Logarithm of each element with respect to an arbitrary base.
137133 fn log( A )
138-
139134 /// The positive difference between given number and each element.
140135 fn abs_sub( A )
141136 }
@@ -145,18 +140,37 @@ where
145140 pub fn pow2 ( & self ) -> Array < A , D > {
146141 self . mapv ( |v : A | v * v)
147142 }
143+ }
148144
149- /// Limit the values for each element.
145+ impl < A , S , D > ArrayBase < S , D >
146+ where
147+ A : ' static + PartialOrd + Clone ,
148+ S : Data < Elem = A > ,
149+ D : Dimension ,
150+ {
151+ /// Limit the values for each element, similar to NumPy's `clip` function.
150152 ///
151153 /// ```
152- /// use ndarray::{Array1, array} ;
154+ /// use ndarray::array;
153155 ///
154- /// let a = Array1::range(0., 10., 1.);
155- /// assert_eq!(a.clip(1., 8.), array![1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]);
156- /// assert_eq!(a.clip(8., 1.), array![1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
157- /// assert_eq!(a.clip(3., 6.), array![3., 3., 3., 3., 4., 5., 6., 6., 6., 6.]);
156+ /// let a = array![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.];
157+ /// assert_eq!(a.clamp(1., 8.), array![1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]);
158+ /// assert_eq!(a.clamp(3., 6.), array![3., 3., 3., 3., 4., 5., 6., 6., 6., 6.]);
158159 /// ```
159- pub fn clip ( & self , min : A , max : A ) -> Array < A , D > {
160- self . mapv ( |v| A :: min ( A :: max ( v, min) , max) )
160+ ///
161+ /// # Panics
162+ ///
163+ /// Panics if `min > max`, `min` is `NaN`, or `max` is `NaN`.
164+ pub fn clamp ( & self , min : A , max : A ) -> Array < A , D > {
165+ assert ! ( min <= max, "min must be less than or equal to max" ) ;
166+ self . mapv ( |v| {
167+ if v < min {
168+ min. clone ( )
169+ } else if v > max {
170+ max. clone ( )
171+ } else {
172+ v
173+ }
174+ } )
161175 }
162176}
0 commit comments