@@ -3,6 +3,7 @@ extern crate defmac;
33extern crate ndarray;
44extern crate num_traits;
55extern crate blas_src;
6+ extern crate num_complex;
67
78use ndarray:: prelude:: * ;
89
@@ -12,6 +13,8 @@ use ndarray::{Data, Ix, LinalgScalar};
1213
1314use approx:: assert_relative_eq;
1415use defmac:: defmac;
16+ use num_complex:: Complex32 ;
17+ use num_complex:: Complex64 ;
1518
1619#[ test]
1720fn mat_vec_product_1d ( ) {
@@ -52,6 +55,20 @@ fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
5255 . unwrap ( )
5356}
5457
58+ fn range_mat_complex ( m : Ix , n : Ix ) -> Array2 < Complex32 > {
59+ Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
60+ . into_shape ( ( m, n) )
61+ . unwrap ( )
62+ . map_mut ( |& mut f| Complex32 :: new ( f, 0. ) )
63+ }
64+
65+ fn range_mat_complex64 ( m : Ix , n : Ix ) -> Array2 < Complex64 > {
66+ Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
67+ . into_shape ( ( m, n) )
68+ . unwrap ( )
69+ . map_mut ( |& mut f| Complex64 :: new ( f, 0. ) )
70+ }
71+
5572fn range1_mat64 ( m : Ix ) -> Array1 < f64 > {
5673 Array :: linspace ( 0. , m as f64 - 1. , m)
5774}
@@ -250,6 +267,30 @@ fn gemm_64_1_f() {
250267 assert_relative_eq ! ( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
251268}
252269
270+ #[ test]
271+ fn gemm_c64_1_f ( ) {
272+ let a = range_mat_complex64 ( 64 , 64 ) . reversed_axes ( ) ;
273+ let ( m, n) = a. dim ( ) ;
274+ // m x n times n x 1 == m x 1
275+ let x = range_mat_complex64 ( n, 1 ) ;
276+ let mut y = range_mat_complex64 ( m, 1 ) ;
277+ let answer = reference_mat_mul ( & a, & x) + & y;
278+ general_mat_mul ( Complex64 :: new ( 1.0 , 0. ) , & a, & x, Complex64 :: new ( 1.0 , 0. ) , & mut y) ;
279+ assert_relative_eq ! ( y. mapv( |i| i. norm_sqr( ) ) , answer. mapv( |i| i. norm_sqr( ) ) , epsilon = 1e-12 , max_relative = 1e-7 ) ;
280+ }
281+
282+ #[ test]
283+ fn gemm_c32_1_f ( ) {
284+ let a = range_mat_complex ( 64 , 64 ) . reversed_axes ( ) ;
285+ let ( m, n) = a. dim ( ) ;
286+ // m x n times n x 1 == m x 1
287+ let x = range_mat_complex ( n, 1 ) ;
288+ let mut y = range_mat_complex ( m, 1 ) ;
289+ let answer = reference_mat_mul ( & a, & x) + & y;
290+ general_mat_mul ( Complex32 :: new ( 1.0 , 0. ) , & a, & x, Complex32 :: new ( 1.0 , 0. ) , & mut y) ;
291+ assert_relative_eq ! ( y. mapv( |i| i. norm_sqr( ) ) , answer. mapv( |i| i. norm_sqr( ) ) , epsilon = 1e-12 , max_relative = 1e-7 ) ;
292+ }
293+
253294#[ test]
254295fn gen_mat_vec_mul ( ) {
255296 let alpha = -2.3 ;
0 commit comments