77// except according to those terms.
88
99use crate :: imp_prelude:: * ;
10- use crate :: numeric_util ;
10+
1111#[ cfg( feature = "blas" ) ]
1212use crate :: dimension:: offset_from_low_addr_ptr_to_logical_ptr;
13+ use crate :: numeric_util;
1314
1415use crate :: { LinalgScalar , Zip } ;
1516
1617use std:: any:: TypeId ;
1718use std:: mem:: MaybeUninit ;
1819use alloc:: vec:: Vec ;
1920
21+ #[ cfg( feature = "blas" ) ]
22+ use libc:: c_int;
2023#[ cfg( feature = "blas" ) ]
2124use std:: cmp;
2225#[ cfg( feature = "blas" ) ]
2326use std:: mem:: swap;
24- #[ cfg( feature = "blas" ) ]
25- use libc:: c_int;
2627
2728#[ cfg( feature = "blas" ) ]
2829use cblas_sys as blas_sys;
2930#[ cfg( feature = "blas" ) ]
3031use cblas_sys:: { CblasNoTrans , CblasRowMajor , CblasTrans , CBLAS_LAYOUT } ;
3132
33+ #[ cfg( feature = "blas" ) ]
34+ use num_complex:: { Complex32 as c32, Complex64 as c64} ;
35+
3236/// len of vector before we use blas
3337#[ cfg( feature = "blas" ) ]
3438const DOT_BLAS_CUTOFF : usize = 32 ;
@@ -377,7 +381,12 @@ fn mat_mul_impl<A>(
377381 // size cutoff for using BLAS
378382 let cut = GEMM_BLAS_CUTOFF ;
379383 let ( ( mut m, a) , ( _, mut n) ) = ( lhs. dim ( ) , rhs. dim ( ) ) ;
380- if !( m > cut || n > cut || a > cut) || !( same_type :: < A , f32 > ( ) || same_type :: < A , f64 > ( ) ) {
384+ if !( m > cut || n > cut || a > cut)
385+ || !( same_type :: < A , f32 > ( )
386+ || same_type :: < A , f64 > ( )
387+ || same_type :: < A , c32 > ( )
388+ || same_type :: < A , c64 > ( ) )
389+ {
381390 return mat_mul_general ( alpha, lhs, rhs, beta, c) ;
382391 }
383392 {
@@ -407,8 +416,23 @@ fn mat_mul_impl<A>(
407416 rhs_trans = CblasTrans ;
408417 }
409418
419+ macro_rules! gemm_scalar_cast {
420+ ( f32 , $var: ident) => {
421+ cast_as( & $var)
422+ } ;
423+ ( f64 , $var: ident) => {
424+ cast_as( & $var)
425+ } ;
426+ ( c32, $var: ident) => {
427+ & $var as * const A as * const _
428+ } ;
429+ ( c64, $var: ident) => {
430+ & $var as * const A as * const _
431+ } ;
432+ }
433+
410434 macro_rules! gemm {
411- ( $ty: ty , $gemm: ident) => {
435+ ( $ty: tt , $gemm: ident) => {
412436 if blas_row_major_2d:: <$ty, _>( & lhs_)
413437 && blas_row_major_2d:: <$ty, _>( & rhs_)
414438 && blas_row_major_2d:: <$ty, _>( & c_)
@@ -428,25 +452,25 @@ fn mat_mul_impl<A>(
428452 let lhs_stride = cmp:: max( lhs_. strides( ) [ 0 ] as blas_index, k as blas_index) ;
429453 let rhs_stride = cmp:: max( rhs_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
430454 let c_stride = cmp:: max( c_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
431-
455+
432456 // gemm is C ← αA^Op B^Op + βC
433457 // Where Op is notrans/trans/conjtrans
434458 unsafe {
435459 blas_sys:: $gemm(
436460 CblasRowMajor ,
437461 lhs_trans,
438462 rhs_trans,
439- m as blas_index, // m, rows of Op(a)
440- n as blas_index, // n, cols of Op(b)
441- k as blas_index, // k, cols of Op(a)
442- cast_as ( & alpha) , // alpha
443- lhs_. ptr. as_ptr( ) as * const _, // a
444- lhs_stride, // lda
445- rhs_. ptr. as_ptr( ) as * const _, // b
446- rhs_stride, // ldb
447- cast_as ( & beta) , // beta
448- c_. ptr. as_ptr( ) as * mut _, // c
449- c_stride, // ldc
463+ m as blas_index, // m, rows of Op(a)
464+ n as blas_index, // n, cols of Op(b)
465+ k as blas_index, // k, cols of Op(a)
466+ gemm_scalar_cast! ( $ty , alpha) , // alpha
467+ lhs_. ptr. as_ptr( ) as * const _, // a
468+ lhs_stride, // lda
469+ rhs_. ptr. as_ptr( ) as * const _, // b
470+ rhs_stride, // ldb
471+ gemm_scalar_cast! ( $ty , beta) , // beta
472+ c_. ptr. as_ptr( ) as * mut _, // c
473+ c_stride, // ldc
450474 ) ;
451475 }
452476 return ;
@@ -455,6 +479,9 @@ fn mat_mul_impl<A>(
455479 }
456480 gemm ! ( f32 , cblas_sgemm) ;
457481 gemm ! ( f64 , cblas_dgemm) ;
482+
483+ gemm ! ( c32, cblas_cgemm) ;
484+ gemm ! ( c64, cblas_zgemm) ;
458485 }
459486 mat_mul_general ( alpha, lhs, rhs, beta, c)
460487}
@@ -603,9 +630,7 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
603630 S3 : DataMut < Elem = A > ,
604631 A : LinalgScalar ,
605632{
606- unsafe {
607- general_mat_vec_mul_impl ( alpha, a, x, beta, y. raw_view_mut ( ) )
608- }
633+ unsafe { general_mat_vec_mul_impl ( alpha, a, x, beta, y. raw_view_mut ( ) ) }
609634}
610635
611636/// General matrix-vector multiplication
0 commit comments