@@ -416,8 +416,23 @@ fn mat_mul_impl<A>(
416416 rhs_trans = CblasTrans ;
417417 }
418418
419+ macro_rules! cast_ty {
420+ ( f32 , $var: ident) => {
421+ cast_as( & $var)
422+ } ;
423+ ( f64 , $var: ident) => {
424+ cast_as( & $var)
425+ } ;
426+ ( c32, $var: ident) => {
427+ & $var as * const A as * const _
428+ } ;
429+ ( c64, $var: ident) => {
430+ & $var as * const A as * const _
431+ } ;
432+ }
433+
419434 macro_rules! gemm {
420- ( $ty: ty , $gemm: ident) => {
435+ ( $ty: tt , $gemm: ident) => {
421436 if blas_row_major_2d:: <$ty, _>( & lhs_)
422437 && blas_row_major_2d:: <$ty, _>( & rhs_)
423438 && blas_row_major_2d:: <$ty, _>( & c_)
@@ -437,9 +452,9 @@ fn mat_mul_impl<A>(
437452 let lhs_stride = cmp:: max( lhs_. strides( ) [ 0 ] as blas_index, k as blas_index) ;
438453 let rhs_stride = cmp:: max( rhs_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
439454 let c_stride = cmp:: max( c_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
440-
441455 // gemm is C ← αA^Op B^Op + βC
442456 // Where Op is notrans/trans/conjtrans
457+
443458 unsafe {
444459 blas_sys:: $gemm(
445460 CblasRowMajor ,
@@ -448,12 +463,12 @@ fn mat_mul_impl<A>(
448463 m as blas_index, // m, rows of Op(a)
449464 n as blas_index, // n, cols of Op(b)
450465 k as blas_index, // k, cols of Op(a)
451- cast_as ( & alpha) , // alpha
466+ cast_ty! ( $ty , alpha) , // alpha
452467 lhs_. ptr. as_ptr( ) as * const _, // a
453468 lhs_stride, // lda
454469 rhs_. ptr. as_ptr( ) as * const _, // b
455470 rhs_stride, // ldb
456- cast_as ( & beta) , // beta
471+ cast_ty! ( $ty , beta) , // beta
457472 c_. ptr. as_ptr( ) as * mut _, // c
458473 c_stride, // ldc
459474 ) ;
@@ -465,52 +480,6 @@ fn mat_mul_impl<A>(
465480 gemm ! ( f32 , cblas_sgemm) ;
466481 gemm ! ( f64 , cblas_dgemm) ;
467482
468- macro_rules! gemm {
469- ( $ty: ty, $gemm: ident) => {
470- if blas_row_major_2d:: <$ty, _>( & lhs_)
471- && blas_row_major_2d:: <$ty, _>( & rhs_)
472- && blas_row_major_2d:: <$ty, _>( & c_)
473- {
474- let ( m, k) = match lhs_trans {
475- CblasNoTrans => lhs_. dim( ) ,
476- _ => {
477- let ( rows, cols) = lhs_. dim( ) ;
478- ( cols, rows)
479- }
480- } ;
481- let n = match rhs_trans {
482- CblasNoTrans => rhs_. raw_dim( ) [ 1 ] ,
483- _ => rhs_. raw_dim( ) [ 0 ] ,
484- } ;
485- // adjust strides, these may [1, 1] for column matrices
486- let lhs_stride = cmp:: max( lhs_. strides( ) [ 0 ] as blas_index, k as blas_index) ;
487- let rhs_stride = cmp:: max( rhs_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
488- let c_stride = cmp:: max( c_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
489-
490- // gemm is C ← αA^Op B^Op + βC
491- // Where Op is notrans/trans/conjtrans
492- unsafe {
493- blas_sys:: $gemm(
494- CblasRowMajor ,
495- lhs_trans,
496- rhs_trans,
497- m as blas_index, // m, rows of Op(a)
498- n as blas_index, // n, cols of Op(b)
499- k as blas_index, // k, cols of Op(a)
500- & alpha as * const A as * const _, // alpha
501- lhs_. ptr. as_ptr( ) as * const _, // a
502- lhs_stride, // lda
503- rhs_. ptr. as_ptr( ) as * const _, // b
504- rhs_stride, // ldb
505- & beta as * const A as * const _, // beta
506- c_. ptr. as_ptr( ) as * mut _, // c
507- c_stride, // ldc
508- ) ;
509- }
510- return ;
511- }
512- } ;
513- }
514483 gemm ! ( c32, cblas_cgemm) ;
515484 gemm ! ( c64, cblas_zgemm) ;
516485 }
0 commit comments