@@ -11,6 +11,7 @@ use ndarray::{
1111 Data ,
1212 LinalgScalar ,
1313} ;
14+ use ndarray:: linalg:: general_mat_mul;
1415
1516use rand:: distributions:: Normal ;
1617
@@ -161,6 +162,42 @@ fn accurate_mul_f32() {
161162 }
162163}
163164
/// Compares `general_mat_mul` on `f32` operands against `reference_mat_mul`.
///
/// Runs 20 rounds with randomly drawn sizes. Rounds with index above 10
/// operate on every-second-row / every-second-column slices of the operands
/// and of the output; the remaining rounds use full views. The test panics
/// when any element's absolute difference exceeds `|c| * 1e-3 + 1e-4`.
#[test]
fn accurate_mul_f32_general() {
    // pick a few random sizes
    let mut rng = SmallRng::from_entropy();
    for round in 0..20 {
        let m = rng.gen_range(15, 512);
        let k = rng.gen_range(15, 512);
        let n = rng.gen_range(15, 1560);

        let lhs_owned = gen(Ix2(m, k));
        let rhs_owned = gen(Ix2(n, k));
        let mut out_owned = gen(Ix2(m, n));
        // The right-hand operand is generated as (n, k) and used transposed.
        let rhs_t = rhs_owned.t();

        // Later rounds work on strided (step 2 along both axes) views.
        let strided = round > 10;
        let lhs = if strided {
            lhs_owned.slice(s![..;2, ..;2])
        } else {
            lhs_owned.view()
        };
        let rhs = if strided {
            rhs_t.slice(s![..;2, ..;2])
        } else {
            rhs_t
        };
        let mut out = if strided {
            out_owned.slice_mut(s![..;2, ..;2])
        } else {
            out_owned.view_mut()
        };

        let (rows, inner) = lhs.dim();
        let (_, cols) = rhs.dim();
        println!("Testing size {} by {} by {}", rows, inner, cols);

        // Scaling arguments are 1. for the product and 0. for the output.
        general_mat_mul(1., &lhs, &rhs, 0., &mut out);
        let expected = reference_mat_mul(&lhs, &rhs);
        let abs_err = (&out - &expected).mapv_into(f32::abs);

        let rtol = 1e-3;
        let atol = 1e-4;
        // Per-element tolerance: relative part scaled by the computed value,
        // plus a fixed absolute part.
        let allowed = out.mapv(|v| v.abs() * rtol) + atol;
        let excess = &abs_err - &allowed;
        let worst = *excess.max();
        println!("diff offset from tolerance level= {:.2e}", worst);
        if worst > 0. {
            panic!("results differ");
        }
    }
}
200+
164201#[ test]
165202fn accurate_mul_f64 ( ) {
166203 // pick a few random sizes
@@ -195,6 +232,41 @@ fn accurate_mul_f64() {
195232 }
196233}
197234
/// Compares `general_mat_mul` on `f64` operands against `reference_mat_mul`.
///
/// Runs 20 rounds with randomly drawn sizes. Rounds with index above 10
/// operate on every-second-row / every-second-column slices of the operands
/// and of the output; the remaining rounds use full views. The test panics
/// when any element's absolute difference exceeds `|c| * 1e-7 + 1e-12`.
#[test]
fn accurate_mul_f64_general() {
    // pick a few random sizes
    let mut rng = SmallRng::from_entropy();
    for round in 0..20 {
        let m = rng.gen_range(15, 512);
        let k = rng.gen_range(15, 512);
        let n = rng.gen_range(15, 1560);

        let lhs_owned = gen_f64(Ix2(m, k));
        let rhs_owned = gen_f64(Ix2(n, k));
        let mut out_owned = gen_f64(Ix2(m, n));
        // The right-hand operand is generated as (n, k) and used transposed.
        let rhs_t = rhs_owned.t();

        // Later rounds work on strided (step 2 along both axes) views.
        let strided = round > 10;
        let lhs = if strided {
            lhs_owned.slice(s![..;2, ..;2])
        } else {
            lhs_owned.view()
        };
        let rhs = if strided {
            rhs_t.slice(s![..;2, ..;2])
        } else {
            rhs_t
        };
        let mut out = if strided {
            out_owned.slice_mut(s![..;2, ..;2])
        } else {
            out_owned.view_mut()
        };

        let (rows, inner) = lhs.dim();
        let (_, cols) = rhs.dim();
        println!("Testing size {} by {} by {}", rows, inner, cols);

        // Scaling arguments are 1. for the product and 0. for the output.
        general_mat_mul(1., &lhs, &rhs, 0., &mut out);
        let expected = reference_mat_mul(&lhs, &rhs);
        let abs_err = (&out - &expected).mapv_into(f64::abs);

        let rtol = 1e-7;
        let atol = 1e-12;
        // Per-element tolerance: relative part scaled by the computed value,
        // plus a fixed absolute part.
        let allowed = out.mapv(|v| v.abs() * rtol) + atol;
        let excess = &abs_err - &allowed;
        let worst = *excess.max();
        println!("diff offset from tolerance level= {:.2e}", worst);
        if worst > 0. {
            panic!("results differ");
        }
    }
}
198270
199271#[ test]
200272fn accurate_mul_with_column_f64 ( ) {
0 commit comments