@@ -1,5 +1,5 @@
 //! Information theory (e.g. entropy, KL divergence, etc.).
-use crate::errors::ShapeMismatch;
+use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
 use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
 use num_traits::Float;
 
@@ -19,7 +19,7 @@
     /// i=1
     /// ```
     ///
-    /// If the array is empty, `None` is returned.
+    /// If the array is empty, `Err(EmptyInput)` is returned.
     ///
     /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
     ///
@@ -38,7 +38,7 @@
     ///
     /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
-    fn entropy(&self) -> Option<A>
+    fn entropy(&self) -> Result<A, EmptyInput>
     where
         A: Float;
 
@@ -53,8 +53,9 @@
     /// i=1
     /// ```
     ///
-    /// If the arrays are empty, Ok(`None`) is returned.
-    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
+    /// If the array shapes are not identical,
+    /// `Err(MultiInputError::ShapeMismatch)` is returned.
     ///
     /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
     /// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
@@ -73,7 +74,7 @@
     ///
     /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
-    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         S2: Data<Elem = A>,
         A: Float;
@@ -89,8 +90,9 @@
     /// i=1
     /// ```
     ///
-    /// If the arrays are empty, Ok(`None`) is returned.
-    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
+    /// If the array shapes are not identical,
+    /// `Err(MultiInputError::ShapeMismatch)` is returned.
     ///
     /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
     /// is a panic cause for `A`.
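The doc hunks above replace the old `Result<Option<A>, ShapeMismatch>` contract with a single `MultiInputError` that distinguishes the two failure modes. For illustration, a caller might branch on them as below; this is a minimal sketch assuming `MultiInputError` has exactly the `EmptyInput` and `ShapeMismatch(ShapeMismatch)` variants implied by these doc comments and the `.into()` calls further down, and that the trait is consumed as `ndarray_stats::EntropyExt`:

```rust
// Caller-side sketch of the new error contract; the variant shapes and
// import paths are assumptions inferred from this diff, not confirmed by it.
use ndarray::array;
use ndarray_stats::{errors::MultiInputError, EntropyExt};

fn main() {
    let p = array![0.5, 0.5];
    let q = array![0.25, 0.25, 0.5]; // deliberately the wrong shape

    match p.kl_divergence(&q) {
        Ok(kl) => println!("KL(p ‖ q) = {}", kl),
        Err(MultiInputError::EmptyInput) => println!("inputs were empty"),
        Err(MultiInputError::ShapeMismatch(e)) => println!(
            "shape mismatch: {:?} vs {:?}",
            e.first_shape, e.second_shape
        ),
    }
}
```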
@@ -114,7 +116,7 @@ where
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
     /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
     /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
-    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         S2: Data<Elem = A>,
         A: Float;
@@ -125,14 +127,14 @@ where
     S: Data<Elem = A>,
     D: Dimension,
 {
-    fn entropy(&self) -> Option<A>
+    fn entropy(&self) -> Result<A, EmptyInput>
     where
         A: Float,
     {
         if self.len() == 0 {
-            None
+            Err(EmptyInput)
         } else {
-            let entropy = self
+            let entropy = -self
                 .mapv(|x| {
                     if x == A::zero() {
                         A::zero()
@@ -141,23 +143,24 @@ where
                     }
                 })
                 .sum();
-            Some(-entropy)
+            Ok(entropy)
         }
     }
 
-    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         A: Float,
         S2: Data<Elem = A>,
     {
         if self.len() == 0 {
-            return Ok(None);
+            return Err(MultiInputError::EmptyInput);
         }
         if self.shape() != q.shape() {
             return Err(ShapeMismatch {
                 first_shape: self.shape().to_vec(),
                 second_shape: q.shape().to_vec(),
-            });
+            }
+            .into());
         }
 
         let mut temp = Array::zeros(self.raw_dim());
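The `}` plus `.into());` rewrite in this hunk leans on a `From<ShapeMismatch> for MultiInputError` conversion. The authoritative definitions live in the crate's `errors` module; a hypothetical reconstruction, consistent with how the diff uses the type (a unit-like `EmptyInput`, a wrapping `ShapeMismatch` variant, and the `is_empty_input` helper exercised by the tests below), would be:

```rust
// Hypothetical reconstruction of the error plumbing that `.into()` relies on;
// the real definitions are in the crate's `errors` module.
#[derive(Debug, PartialEq)]
pub struct ShapeMismatch {
    pub first_shape: Vec<usize>,
    pub second_shape: Vec<usize>,
}

#[derive(Debug, PartialEq)]
pub enum MultiInputError {
    /// At least one input was empty.
    EmptyInput,
    /// The input shapes did not match.
    ShapeMismatch(ShapeMismatch),
}

impl MultiInputError {
    /// `true` iff this is the `EmptyInput` variant.
    pub fn is_empty_input(&self) -> bool {
        matches!(self, MultiInputError::EmptyInput)
    }
}

impl From<ShapeMismatch> for MultiInputError {
    fn from(err: ShapeMismatch) -> Self {
        MultiInputError::ShapeMismatch(err)
    }
}
```

With this in place, `ShapeMismatch { .. }.into()` selects the `From` impl through the blanket `Into` implementation, so the construction sites stay unchanged apart from the trailing conversion.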
@@ -174,22 +177,23 @@ where
                 }
             });
         let kl_divergence = -temp.sum();
-        Ok(Some(kl_divergence))
+        Ok(kl_divergence)
     }
 
-    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         S2: Data<Elem = A>,
         A: Float,
     {
         if self.len() == 0 {
-            return Ok(None);
+            return Err(MultiInputError::EmptyInput);
         }
         if self.shape() != q.shape() {
             return Err(ShapeMismatch {
                 first_shape: self.shape().to_vec(),
                 second_shape: q.shape().to_vec(),
-            });
+            }
+            .into());
         }
 
         let mut temp = Array::zeros(self.raw_dim());
@@ -206,15 +210,15 @@ where
                 }
             });
         let cross_entropy = -temp.sum();
-        Ok(Some(cross_entropy))
+        Ok(cross_entropy)
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::EntropyExt;
     use approx::assert_abs_diff_eq;
-    use errors::ShapeMismatch;
+    use errors::{EmptyInput, MultiInputError};
     use ndarray::{array, Array1};
     use noisy_float::types::n64;
     use std::f64;
@@ -228,7 +232,7 @@ mod tests {
     #[test]
     fn test_entropy_with_empty_array_of_floats() {
         let a: Array1<f64> = array![];
-        assert!(a.entropy().is_none());
+        assert_eq!(a.entropy(), Err(EmptyInput));
     }
 
     #[test]
@@ -251,13 +255,13 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
         let a = array![f64::NAN, 1.];
         let b = array![2., 1.];
-        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
-        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
-        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
-        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
+        assert!(a.cross_entropy(&b)?.is_nan());
+        assert!(b.cross_entropy(&a)?.is_nan());
+        assert!(a.kl_divergence(&b)?.is_nan());
+        assert!(b.kl_divergence(&a)?.is_nan());
         Ok(())
     }
 
@@ -284,20 +288,19 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
         let p: Array1<f64> = array![];
         let q: Array1<f64> = array![];
-        assert!(p.cross_entropy(&q)?.is_none());
-        assert!(p.kl_divergence(&q)?.is_none());
-        Ok(())
+        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
+        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
         let p = array![1.];
         let q = array![-1.];
-        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
-        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
+        let cross_entropy: f64 = p.cross_entropy(&q)?;
+        let kl_divergence: f64 = p.kl_divergence(&q)?;
         assert!(cross_entropy.is_nan());
         assert!(kl_divergence.is_nan());
         Ok(())
@@ -320,26 +323,26 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
         let p = array![0., 0.];
         let q = array![0., 0.5];
-        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
-        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
+        assert_eq!(p.cross_entropy(&q)?, 0.);
+        assert_eq!(p.kl_divergence(&q)?, 0.);
         Ok(())
     }
 
     #[test]
     fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
-    ) -> Result<(), ShapeMismatch> {
+    ) -> Result<(), MultiInputError> {
         let p = array![0.5, 0.5];
         let mut q = array![0.5, 0.];
-        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
-        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
+        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
+        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
         Ok(())
     }
 
     #[test]
-    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy() -> Result<(), MultiInputError> {
         // Arrays of probability values - normalized and positive.
         let p: Array1<f64> = array![
             0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
@@ -356,16 +359,12 @@ mod tests {
         // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
         let expected_cross_entropy = 3.385347705020779;
 
-        assert_abs_diff_eq!(
-            p.cross_entropy(&q)?.unwrap(),
-            expected_cross_entropy,
-            epsilon = 1e-6
-        );
+        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
         Ok(())
     }
 
     #[test]
-    fn test_kl() -> Result<(), ShapeMismatch> {
+    fn test_kl() -> Result<(), MultiInputError> {
         // Arrays of probability values - normalized and positive.
         let p: Array1<f64> = array![
             0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
@@ -390,7 +389,7 @@ mod tests {
         // Computed using scipy.stats.entropy(p, q)
         let expected_kl = 0.3555862567800096;
 
-        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
+        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
         Ok(())
     }
 }