@@ -18,6 +18,7 @@ use defmac::defmac;
1818use itertools:: iproduct;
1919use num_complex:: Complex32 ;
2020use num_complex:: Complex64 ;
21+ use num_traits:: Num ;
2122
2223#[ test]
2324fn mat_vec_product_1d ( )
@@ -49,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis()
4950 assert_eq ! ( a. t( ) . dot( & b) , ans) ;
5051}
5152
52- fn range_mat ( m : Ix , n : Ix ) -> Array2 < f32 >
53+ fn range_mat < A : Num + Copy > ( m : Ix , n : Ix ) -> Array2 < A >
5354{
54- Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
55- . into_shape_with_order ( ( m, n) )
56- . unwrap ( )
57- }
58-
59- fn range_mat64 ( m : Ix , n : Ix ) -> Array2 < f64 >
60- {
61- Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
62- . into_shape_with_order ( ( m, n) )
63- . unwrap ( )
55+ ArrayBuilder :: new ( ( m, n) ) . build ( )
6456}
6557
6658fn range_mat_complex ( m : Ix , n : Ix ) -> Array2 < Complex32 >
6759{
68- Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
69- . into_shape_with_order ( ( m, n) )
70- . unwrap ( )
71- . map ( |& f| Complex32 :: new ( f, 0. ) )
60+ ArrayBuilder :: new ( ( m, n) ) . build ( )
7261}
7362
7463fn range_mat_complex64 ( m : Ix , n : Ix ) -> Array2 < Complex64 >
7564{
76- Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
77- . into_shape_with_order ( ( m, n) )
78- . unwrap ( )
79- . map ( |& f| Complex64 :: new ( f, 0. ) )
65+ ArrayBuilder :: new ( ( m, n) ) . build ( )
8066}
8167
8268fn range1_mat64 ( m : Ix ) -> Array1 < f64 >
8369{
84- Array :: linspace ( 0. , m as f64 - 1. , m )
70+ ArrayBuilder :: new ( m ) . build ( )
8571}
8672
8773fn range_i32 ( m : Ix , n : Ix ) -> Array2 < i32 >
8874{
89- Array :: from_iter ( 0 ..( m * n) as i32 )
90- . into_shape_with_order ( ( m, n) )
91- . unwrap ( )
75+ ArrayBuilder :: new ( ( m, n) ) . build ( )
9276}
9377
9478// simple, slow, correct (hopefully) mat mul
@@ -163,8 +147,8 @@ where
163147fn mat_mul_order ( )
164148{
165149 let ( m, n, k) = ( 50 , 50 , 50 ) ;
166- let a = range_mat ( m, n) ;
167- let b = range_mat ( n, k) ;
150+ let a = range_mat :: < f32 > ( m, n) ;
151+ let b = range_mat :: < f32 > ( n, k) ;
168152 let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
169153 let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
170154 af. assign ( & a) ;
@@ -183,7 +167,7 @@ fn mat_mul_order()
183167fn mat_mul_broadcast ( )
184168{
185169 let ( m, n, k) = ( 16 , 16 , 16 ) ;
186- let a = range_mat ( m, n) ;
170+ let a = range_mat :: < f32 > ( m, n) ;
187171 let x1 = 1. ;
188172 let x = Array :: from ( vec ! [ x1] ) ;
189173 let b0 = x. broadcast ( ( n, k) ) . unwrap ( ) ;
@@ -203,8 +187,8 @@ fn mat_mul_broadcast()
203187fn mat_mul_rev ( )
204188{
205189 let ( m, n, k) = ( 16 , 16 , 16 ) ;
206- let a = range_mat ( m, n) ;
207- let b = range_mat ( n, k) ;
190+ let a = range_mat :: < f32 > ( m, n) ;
191+ let b = range_mat :: < f32 > ( n, k) ;
208192 let mut rev = Array :: zeros ( b. dim ( ) ) ;
209193 let mut rev = rev. slice_mut ( s ! [ ..; -1 , ..] ) ;
210194 rev. assign ( & b) ;
@@ -233,8 +217,8 @@ fn mat_mut_zero_len()
233217 }
234218 }
235219 } ) ;
236- mat_mul_zero_len ! ( range_mat) ;
237- mat_mul_zero_len ! ( range_mat64 ) ;
220+ mat_mul_zero_len ! ( range_mat:: < f32 > ) ;
221+ mat_mul_zero_len ! ( range_mat :: < f64 > ) ;
238222 mat_mul_zero_len ! ( range_i32) ;
239223}
240224
@@ -307,11 +291,11 @@ fn gen_mat_mul()
307291#[ test]
308292fn gemm_64_1_f ( )
309293{
310- let a = range_mat64 ( 64 , 64 ) . reversed_axes ( ) ;
294+ let a = range_mat :: < f64 > ( 64 , 64 ) . reversed_axes ( ) ;
311295 let ( m, n) = a. dim ( ) ;
312296 // m x n times n x 1 == m x 1
313- let x = range_mat64 ( n, 1 ) ;
314- let mut y = range_mat64 ( m, 1 ) ;
297+ let x = range_mat :: < f64 > ( n, 1 ) ;
298+ let mut y = range_mat :: < f64 > ( m, 1 ) ;
315299 let answer = reference_mat_mul ( & a, & x) + & y;
316300 general_mat_mul ( 1.0 , & a, & x, 1.0 , & mut y) ;
317301 assert_relative_eq ! ( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
@@ -393,11 +377,8 @@ fn gen_mat_vec_mul()
393377 for & s1 in & [ 1 , 2 , -1 , -2 ] {
394378 for & s2 in & [ 1 , 2 , -1 , -2 ] {
395379 for & ( m, k) in & sizes {
396- for & rev in & [ false , true ] {
397- let mut a = range_mat64 ( m, k) ;
398- if rev {
399- a = a. reversed_axes ( ) ;
400- }
380+ for order in [ Order :: C , Order :: F ] {
381+ let a = ArrayBuilder :: new ( ( m, k) ) . memory_order ( order) . build ( ) ;
401382 let ( m, k) = a. dim ( ) ;
402383 let b = range1_mat64 ( k) ;
403384 let mut c = range1_mat64 ( m) ;
@@ -438,11 +419,8 @@ fn vec_mat_mul()
438419 for & s1 in & [ 1 , 2 , -1 , -2 ] {
439420 for & s2 in & [ 1 , 2 , -1 , -2 ] {
440421 for & ( m, n) in & sizes {
441- for & rev in & [ false , true ] {
442- let mut b = range_mat64 ( m, n) ;
443- if rev {
444- b = b. reversed_axes ( ) ;
445- }
422+ for order in [ Order :: C , Order :: F ] {
423+ let b = ArrayBuilder :: new ( ( m, n) ) . memory_order ( order) . build ( ) ;
446424 let ( m, n) = b. dim ( ) ;
447425 let a = range1_mat64 ( m) ;
448426 let mut c = range1_mat64 ( n) ;
0 commit comments