@@ -4,64 +4,15 @@ extern crate ndarray;
44extern crate num_traits;
55extern crate blas_src;
66
7+ use ndarray:: prelude:: * ;
8+
79use ndarray:: linalg:: general_mat_mul;
810use ndarray:: linalg:: general_mat_vec_mul;
9- use ndarray:: prelude:: * ;
10- use ndarray:: { Data , Ix , Ixs , LinalgScalar , Slice , SliceInfoElem } ;
11+ use ndarray:: { Data , Ix , LinalgScalar } ;
1112
12- use approx:: { assert_abs_diff_eq , assert_relative_eq} ;
13+ use approx:: assert_relative_eq;
1314use defmac:: defmac;
1415
15- fn reference_dot < ' a , A , V1 , V2 > ( a : V1 , b : V2 ) -> A
16- where
17- A : NdFloat ,
18- V1 : AsArray < ' a , A > ,
19- V2 : AsArray < ' a , A > ,
20- {
21- let a = a. into ( ) ;
22- let b = b. into ( ) ;
23- a. iter ( )
24- . zip ( b. iter ( ) )
25- . fold ( A :: zero ( ) , |acc, ( & x, & y) | acc + x * y)
26- }
27-
28- #[ test]
29- fn dot_product ( ) {
30- let a = Array :: range ( 0. , 69. , 1. ) ;
31- let b = & a * 2. - 7. ;
32- let dot = 197846. ;
33- assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
34-
35- // test different alignments
36- let max = 8 as Ixs ;
37- for i in 1 ..max {
38- let a1 = a. slice ( s ! [ i..] ) ;
39- let b1 = b. slice ( s ! [ i..] ) ;
40- assert_abs_diff_eq ! ( a1. dot( & b1) , reference_dot( & a1, & b1) , epsilon = 1e-5 ) ;
41- let a2 = a. slice ( s ! [ ..-i] ) ;
42- let b2 = b. slice ( s ! [ i..] ) ;
43- assert_abs_diff_eq ! ( a2. dot( & b2) , reference_dot( & a2, & b2) , epsilon = 1e-5 ) ;
44- }
45-
46- let a = a. map ( |f| * f as f32 ) ;
47- let b = b. map ( |f| * f as f32 ) ;
48- assert_abs_diff_eq ! ( a. dot( & b) , dot as f32 , epsilon = 1e-5 ) ;
49-
50- let max = 8 as Ixs ;
51- for i in 1 ..max {
52- let a1 = a. slice ( s ! [ i..] ) ;
53- let b1 = b. slice ( s ! [ i..] ) ;
54- assert_abs_diff_eq ! ( a1. dot( & b1) , reference_dot( & a1, & b1) , epsilon = 1e-5 ) ;
55- let a2 = a. slice ( s ! [ ..-i] ) ;
56- let b2 = b. slice ( s ! [ i..] ) ;
57- assert_abs_diff_eq ! ( a2. dot( & b2) , reference_dot( & a2, & b2) , epsilon = 1e-5 ) ;
58- }
59-
60- let a = a. map ( |f| * f as i32 ) ;
61- let b = b. map ( |f| * f as i32 ) ;
62- assert_eq ! ( a. dot( & b) , dot as i32 ) ;
63- }
64-
6516#[ test]
6617fn mat_vec_product_1d ( ) {
6718 let a = arr2 ( & [ [ 1. ] , [ 2. ] ] ) ;
@@ -70,46 +21,6 @@ fn mat_vec_product_1d() {
7021 assert_eq ! ( a. t( ) . dot( & b) , ans) ;
7122}
7223
73- // test that we can dot product with a broadcast array
74- #[ test]
75- fn dot_product_0 ( ) {
76- let a = Array :: range ( 0. , 69. , 1. ) ;
77- let x = 1.5 ;
78- let b = aview0 ( & x) ;
79- let b = b. broadcast ( a. dim ( ) ) . unwrap ( ) ;
80- assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
81-
82- // test different alignments
83- let max = 8 as Ixs ;
84- for i in 1 ..max {
85- let a1 = a. slice ( s ! [ i..] ) ;
86- let b1 = b. slice ( s ! [ i..] ) ;
87- assert_abs_diff_eq ! ( a1. dot( & b1) , reference_dot( & a1, & b1) , epsilon = 1e-5 ) ;
88- let a2 = a. slice ( s ! [ ..-i] ) ;
89- let b2 = b. slice ( s ! [ i..] ) ;
90- assert_abs_diff_eq ! ( a2. dot( & b2) , reference_dot( & a2, & b2) , epsilon = 1e-5 ) ;
91- }
92- }
93-
94- #[ test]
95- fn dot_product_neg_stride ( ) {
96- // test that we can dot with negative stride
97- let a = Array :: range ( 0. , 69. , 1. ) ;
98- let b = & a * 2. - 7. ;
99- for stride in -10 ..0 {
100- // both negative
101- let a = a. slice ( s ! [ ..; stride] ) ;
102- let b = b. slice ( s ! [ ..; stride] ) ;
103- assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
104- }
105- for stride in -10 ..0 {
106- // mixed
107- let a = a. slice ( s ! [ ..; -stride] ) ;
108- let b = b. slice ( s ! [ ..; stride] ) ;
109- assert_abs_diff_eq ! ( a. dot( & b) , reference_dot( & a, & b) , epsilon = 1e-5 ) ;
110- }
111- }
112-
11324fn range_mat ( m : Ix , n : Ix ) -> Array2 < f32 > {
11425 Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
11526 . into_shape ( ( m, n) )
@@ -190,71 +101,11 @@ where
190101 . unwrap ( )
191102}
192103
193- #[ test]
194- fn mat_mul ( ) {
195- let ( m, n, k) = ( 8 , 8 , 8 ) ;
196- let a = range_mat ( m, n) ;
197- let b = range_mat ( n, k) ;
198- let mut b = b / 4. ;
199- {
200- let mut c = b. column_mut ( 0 ) ;
201- c += 1.0 ;
202- }
203- let ab = a. dot ( & b) ;
204-
205- let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
206- let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
207- af. assign ( & a) ;
208- bf. assign ( & b) ;
209-
210- assert_eq ! ( ab, a. dot( & bf) ) ;
211- assert_eq ! ( ab, af. dot( & b) ) ;
212- assert_eq ! ( ab, af. dot( & bf) ) ;
213-
214- let ( m, n, k) = ( 10 , 5 , 11 ) ;
215- let a = range_mat ( m, n) ;
216- let b = range_mat ( n, k) ;
217- let mut b = b / 4. ;
218- {
219- let mut c = b. column_mut ( 0 ) ;
220- c += 1.0 ;
221- }
222- let ab = a. dot ( & b) ;
223-
224- let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
225- let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
226- af. assign ( & a) ;
227- bf. assign ( & b) ;
228-
229- assert_eq ! ( ab, a. dot( & bf) ) ;
230- assert_eq ! ( ab, af. dot( & b) ) ;
231- assert_eq ! ( ab, af. dot( & bf) ) ;
232-
233- let ( m, n, k) = ( 10 , 8 , 1 ) ;
234- let a = range_mat ( m, n) ;
235- let b = range_mat ( n, k) ;
236- let mut b = b / 4. ;
237- {
238- let mut c = b. column_mut ( 0 ) ;
239- c += 1.0 ;
240- }
241- let ab = a. dot ( & b) ;
242-
243- let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
244- let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
245- af. assign ( & a) ;
246- bf. assign ( & b) ;
247-
248- assert_eq ! ( ab, a. dot( & bf) ) ;
249- assert_eq ! ( ab, af. dot( & b) ) ;
250- assert_eq ! ( ab, af. dot( & bf) ) ;
251- }
252-
253104// Check that matrix multiplication of contiguous matrices returns a
254105// matrix with the same order
255106#[ test]
256107fn mat_mul_order ( ) {
257- let ( m, n, k) = ( 8 , 8 , 8 ) ;
108+ let ( m, n, k) = ( 50 , 50 , 50 ) ;
258109 let a = range_mat ( m, n) ;
259110 let b = range_mat ( n, k) ;
260111 let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
@@ -269,27 +120,6 @@ fn mat_mul_order() {
269120 assert_eq ! ( ff. strides( ) [ 0 ] , 1 ) ;
270121}
271122
272- // test matrix multiplication shape mismatch
273- #[ test]
274- #[ should_panic]
275- fn mat_mul_shape_mismatch ( ) {
276- let ( m, k, k2, n) = ( 8 , 8 , 9 , 8 ) ;
277- let a = range_mat ( m, k) ;
278- let b = range_mat ( k2, n) ;
279- a. dot ( & b) ;
280- }
281-
282- // test matrix multiplication shape mismatch
283- #[ test]
284- #[ should_panic]
285- fn mat_mul_shape_mismatch_2 ( ) {
286- let ( m, k, k2, n) = ( 8 , 8 , 8 , 8 ) ;
287- let a = range_mat ( m, k) ;
288- let b = range_mat ( k2, n) ;
289- let mut c = range_mat ( m, n + 1 ) ;
290- general_mat_mul ( 1. , & a, & b, 1. , & mut c) ;
291- }
292-
293123// Check that matrix multiplication
294124// supports broadcast arrays.
295125#[ test]
@@ -348,102 +178,6 @@ fn mat_mut_zero_len() {
348178 mat_mul_zero_len ! ( range_i32) ;
349179}
350180
351- #[ test]
352- fn scaled_add ( ) {
353- let a = range_mat ( 16 , 15 ) ;
354- let mut b = range_mat ( 16 , 15 ) ;
355- b. mapv_inplace ( f32:: exp) ;
356-
357- let alpha = 0.2_f32 ;
358- let mut c = a. clone ( ) ;
359- c. scaled_add ( alpha, & b) ;
360-
361- let d = alpha * & b + & a;
362- assert_eq ! ( c, d) ;
363- }
364-
365- #[ test]
366- fn scaled_add_2 ( ) {
367- let beta = -2.3 ;
368- let sizes = vec ! [
369- ( 4 , 4 , 1 , 4 ) ,
370- ( 8 , 8 , 1 , 8 ) ,
371- ( 17 , 15 , 17 , 15 ) ,
372- ( 4 , 17 , 4 , 17 ) ,
373- ( 17 , 3 , 1 , 3 ) ,
374- ( 19 , 18 , 19 , 18 ) ,
375- ( 16 , 17 , 16 , 17 ) ,
376- ( 15 , 16 , 15 , 16 ) ,
377- ( 67 , 63 , 1 , 63 ) ,
378- ] ;
379- // test different strides
380- for & s1 in & [ 1 , 2 , -1 , -2 ] {
381- for & s2 in & [ 1 , 2 , -1 , -2 ] {
382- for & ( m, k, n, q) in & sizes {
383- let mut a = range_mat64 ( m, k) ;
384- let mut answer = a. clone ( ) ;
385- let c = range_mat64 ( n, q) ;
386-
387- {
388- let mut av = a. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
389- let c = c. slice ( s ! [ ..; s1, ..; s2] ) ;
390-
391- let mut answerv = answer. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
392- answerv += & ( beta * & c) ;
393- av. scaled_add ( beta, & c) ;
394- }
395- assert_relative_eq ! ( a, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
396- }
397- }
398- }
399- }
400-
401- #[ test]
402- fn scaled_add_3 ( ) {
403- let beta = -2.3 ;
404- let sizes = vec ! [
405- ( 4 , 4 , 1 , 4 ) ,
406- ( 8 , 8 , 1 , 8 ) ,
407- ( 17 , 15 , 17 , 15 ) ,
408- ( 4 , 17 , 4 , 17 ) ,
409- ( 17 , 3 , 1 , 3 ) ,
410- ( 19 , 18 , 19 , 18 ) ,
411- ( 16 , 17 , 16 , 17 ) ,
412- ( 15 , 16 , 15 , 16 ) ,
413- ( 67 , 63 , 1 , 63 ) ,
414- ] ;
415- // test different strides
416- for & s1 in & [ 1 , 2 , -1 , -2 ] {
417- for & s2 in & [ 1 , 2 , -1 , -2 ] {
418- for & ( m, k, n, q) in & sizes {
419- let mut a = range_mat64 ( m, k) ;
420- let mut answer = a. clone ( ) ;
421- let cdim = if n == 1 { vec ! [ q] } else { vec ! [ n, q] } ;
422- let cslice: Vec < SliceInfoElem > = if n == 1 {
423- vec ! [ Slice :: from( ..) . step_by( s2) . into( ) ]
424- } else {
425- vec ! [
426- Slice :: from( ..) . step_by( s1) . into( ) ,
427- Slice :: from( ..) . step_by( s2) . into( ) ,
428- ]
429- } ;
430-
431- let c = range_mat64 ( n, q) . into_shape ( cdim) . unwrap ( ) ;
432-
433- {
434- let mut av = a. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
435- let c = c. slice ( & * cslice) ;
436-
437- let mut answerv = answer. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
438- answerv += & ( beta * & c) ;
439- av. scaled_add ( beta, & c) ;
440- }
441- assert_relative_eq ! ( a, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
442- }
443- }
444- }
445- }
446-
447181#[ test]
448182fn gen_mat_mul ( ) {
449183 let alpha = -2.3 ;
@@ -497,32 +231,6 @@ fn gemm_64_1_f() {
497231 assert_relative_eq ! ( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
498232}
499233
500- #[ test]
501- fn gen_mat_mul_i32 ( ) {
502- let alpha = -1 ;
503- let beta = 2 ;
504- let sizes = vec ! [
505- ( 4 , 4 , 4 ) ,
506- ( 8 , 8 , 8 ) ,
507- ( 17 , 15 , 16 ) ,
508- ( 4 , 17 , 3 ) ,
509- ( 17 , 3 , 22 ) ,
510- ( 19 , 18 , 2 ) ,
511- ( 16 , 17 , 15 ) ,
512- ( 15 , 16 , 17 ) ,
513- ( 67 , 63 , 62 ) ,
514- ] ;
515- for & ( m, k, n) in & sizes {
516- let a = range_i32 ( m, k) ;
517- let b = range_i32 ( k, n) ;
518- let mut c = range_i32 ( m, n) ;
519-
520- let answer = alpha * reference_mat_mul ( & a, & b) + beta * & c;
521- general_mat_mul ( alpha, & a, & b, beta, & mut c) ;
522- assert_eq ! ( & c, & answer) ;
523- }
524- }
525-
526234#[ test]
527235fn gen_mat_vec_mul ( ) {
528236 let alpha = -2.3 ;
0 commit comments