@@ -18,6 +18,9 @@ use std::any::TypeId;
1818use std:: mem:: MaybeUninit ;
1919use alloc:: vec:: Vec ;
2020
21+ use num_complex:: Complex ;
22+ use num_complex:: { Complex32 as c32, Complex64 as c64} ;
23+
2124#[ cfg( feature = "blas" ) ]
2225use libc:: c_int;
2326#[ cfg( feature = "blas" ) ]
@@ -30,9 +33,6 @@ use cblas_sys as blas_sys;
3033#[ cfg( feature = "blas" ) ]
3134use cblas_sys:: { CblasNoTrans , CblasRowMajor , CblasTrans , CBLAS_LAYOUT } ;
3235
33- #[ cfg( feature = "blas" ) ]
34- use num_complex:: { Complex32 as c32, Complex64 as c64} ;
35-
3636/// Vector length threshold before we use BLAS
3737#[ cfg( feature = "blas" ) ]
3838const DOT_BLAS_CUTOFF : usize = 32 ;
@@ -505,7 +505,7 @@ fn mat_mul_general<A>(
505505 let ( rsc, csc) = ( c. strides ( ) [ 0 ] , c. strides ( ) [ 1 ] ) ;
506506 if same_type :: < A , f32 > ( ) {
507507 unsafe {
508- :: matrixmultiply:: sgemm (
508+ matrixmultiply:: sgemm (
509509 m,
510510 k,
511511 n,
@@ -524,7 +524,7 @@ fn mat_mul_general<A>(
524524 }
525525 } else if same_type :: < A , f64 > ( ) {
526526 unsafe {
527- :: matrixmultiply:: dgemm (
527+ matrixmultiply:: dgemm (
528528 m,
529529 k,
530530 n,
@@ -541,6 +541,48 @@ fn mat_mul_general<A>(
541541 csc,
542542 ) ;
543543 }
544+ } else if same_type :: < A , c32 > ( ) {
545+ unsafe {
546+ matrixmultiply:: cgemm (
547+ matrixmultiply:: CGemmOption :: Standard ,
548+ matrixmultiply:: CGemmOption :: Standard ,
549+ m,
550+ k,
551+ n,
552+ complex_array ( cast_as ( & alpha) ) ,
553+ ap as * const _ ,
554+ lhs. strides ( ) [ 0 ] ,
555+ lhs. strides ( ) [ 1 ] ,
556+ bp as * const _ ,
557+ rhs. strides ( ) [ 0 ] ,
558+ rhs. strides ( ) [ 1 ] ,
559+ complex_array ( cast_as ( & beta) ) ,
560+ cp as * mut _ ,
561+ rsc,
562+ csc,
563+ ) ;
564+ }
565+ } else if same_type :: < A , c64 > ( ) {
566+ unsafe {
567+ matrixmultiply:: zgemm (
568+ matrixmultiply:: CGemmOption :: Standard ,
569+ matrixmultiply:: CGemmOption :: Standard ,
570+ m,
571+ k,
572+ n,
573+ complex_array ( cast_as ( & alpha) ) ,
574+ ap as * const _ ,
575+ lhs. strides ( ) [ 0 ] ,
576+ lhs. strides ( ) [ 1 ] ,
577+ bp as * const _ ,
578+ rhs. strides ( ) [ 0 ] ,
579+ rhs. strides ( ) [ 1 ] ,
580+ complex_array ( cast_as ( & beta) ) ,
581+ cp as * mut _ ,
582+ rsc,
583+ csc,
584+ ) ;
585+ }
544586 } else {
545587 // It's a no-op if `c` has zero length.
546588 if c. is_empty ( ) {
@@ -768,10 +810,17 @@ fn same_type<A: 'static, B: 'static>() -> bool {
768810//
769811// **Panics** if `A` and `B` are not the same type
770812fn cast_as < A : ' static + Copy , B : ' static + Copy > ( a : & A ) -> B {
771- assert ! ( same_type:: <A , B >( ) ) ;
813+ assert ! ( same_type:: <A , B >( ) , "expect type {} and {} to match" ,
814+ std:: any:: type_name:: <A >( ) , std:: any:: type_name:: <B >( ) ) ;
772815 unsafe { :: std:: ptr:: read ( a as * const _ as * const B ) }
773816}
774817
818+ /// Return the complex in the form of an array [re, im]
819+ #[ inline]
820+ fn complex_array < A : ' static + Copy > ( z : Complex < A > ) -> [ A ; 2 ] {
821+ [ z. re , z. im ]
822+ }
823+
775824#[ cfg( feature = "blas" ) ]
776825fn blas_compat_1d < A , S > ( a : & ArrayBase < S , Ix1 > ) -> bool
777826where
0 commit comments