@@ -371,32 +371,15 @@ where
371371#[ cfg( not( feature = "blas" ) ) ]
372372use self :: mat_mul_general as mat_mul_impl;
373373
374- #[ rustfmt:: skip]
375374#[ cfg( feature = "blas" ) ]
376- fn mat_mul_impl < A > (
377- alpha : A ,
378- a : & ArrayView2 < ' _ , A > ,
379- b : & ArrayView2 < ' _ , A > ,
380- beta : A ,
381- c : & mut ArrayViewMut2 < ' _ , A > ,
382- ) where
383- A : LinalgScalar ,
375+ fn mat_mul_impl < A > ( alpha : A , a : & ArrayView2 < ' _ , A > , b : & ArrayView2 < ' _ , A > , beta : A , c : & mut ArrayViewMut2 < ' _ , A > )
376+ where A : LinalgScalar
384377{
385- // size cutoff for using BLAS
386- let cut = GEMM_BLAS_CUTOFF ;
387378 let ( ( m, k) , ( k2, n) ) = ( a. dim ( ) , b. dim ( ) ) ;
388379 debug_assert_eq ! ( k, k2) ;
389- if !( m > cut || n > cut || k > cut)
390- || !( same_type :: < A , f32 > ( )
391- || same_type :: < A , f64 > ( )
392- || same_type :: < A , c32 > ( )
393- || same_type :: < A , c64 > ( ) )
380+ if ( m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF )
381+ && ( same_type :: < A , f32 > ( ) || same_type :: < A , f64 > ( ) || same_type :: < A , c32 > ( ) || same_type :: < A , c64 > ( ) )
394382 {
395- return mat_mul_general ( alpha, a, b, beta, c) ;
396- }
397-
398- #[ allow( clippy:: never_loop) ] // MSRV Rust 1.64 does not have break from block
399- ' blas_block: loop {
400383 // Compute A B -> C
401384 // We require for BLAS compatibility that:
402385 // A, B, C are contiguous (stride=1) in their fastest dimension,
@@ -408,75 +391,68 @@ fn mat_mul_impl<A>(
408391 // Apply transpose to A, B as needed if they differ from the row major case.
409392 // If C is CblasColMajor then transpose both A, B (again!)
410393
411- let ( a_layout, b_layout, c_layout) =
412- if let ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout) ) =
413- ( get_blas_compatible_layout ( a) ,
414- get_blas_compatible_layout ( b) ,
415- get_blas_compatible_layout ( c) )
416- {
417- ( a_layout, b_layout, c_layout)
418- } else {
419- break ' blas_block;
420- } ;
421-
422- let cblas_layout = c_layout. to_cblas_layout ( ) ;
423- let a_trans = a_layout. to_cblas_transpose_for ( cblas_layout) ;
424- let lda = blas_stride ( & a, a_layout) ;
425-
426- let b_trans = b_layout. to_cblas_transpose_for ( cblas_layout) ;
427- let ldb = blas_stride ( & b, b_layout) ;
428-
429- let ldc = blas_stride ( & c, c_layout) ;
430-
431- macro_rules! gemm_scalar_cast {
432- ( f32 , $var: ident) => {
433- cast_as( & $var)
434- } ;
435- ( f64 , $var: ident) => {
436- cast_as( & $var)
437- } ;
438- ( c32, $var: ident) => {
439- & $var as * const A as * const _
440- } ;
441- ( c64, $var: ident) => {
442- & $var as * const A as * const _
443- } ;
444- }
394+ if let ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout) ) =
395+ ( get_blas_compatible_layout ( a) , get_blas_compatible_layout ( b) , get_blas_compatible_layout ( c) )
396+ {
397+ let cblas_layout = c_layout. to_cblas_layout ( ) ;
398+ let a_trans = a_layout. to_cblas_transpose_for ( cblas_layout) ;
399+ let lda = blas_stride ( & a, a_layout) ;
400+
401+ let b_trans = b_layout. to_cblas_transpose_for ( cblas_layout) ;
402+ let ldb = blas_stride ( & b, b_layout) ;
403+
404+ let ldc = blas_stride ( & c, c_layout) ;
405+
406+ macro_rules! gemm_scalar_cast {
407+ ( f32 , $var: ident) => {
408+ cast_as( & $var)
409+ } ;
410+ ( f64 , $var: ident) => {
411+ cast_as( & $var)
412+ } ;
413+ ( c32, $var: ident) => {
414+ & $var as * const A as * const _
415+ } ;
416+ ( c64, $var: ident) => {
417+ & $var as * const A as * const _
418+ } ;
419+ }
445420
446- macro_rules! gemm {
447- ( $ty: tt, $gemm: ident) => {
448- if same_type:: <A , $ty>( ) {
449- // gemm is C ← αA^Op B^Op + βC
450- // Where Op is notrans/trans/conjtrans
451- unsafe {
452- blas_sys:: $gemm(
453- cblas_layout,
454- a_trans,
455- b_trans,
456- m as blas_index, // m, rows of Op(a)
457- n as blas_index, // n, cols of Op(b)
458- k as blas_index, // k, cols of Op(a)
459- gemm_scalar_cast!( $ty, alpha) , // alpha
460- a. ptr. as_ptr( ) as * const _, // a
461- lda, // lda
462- b. ptr. as_ptr( ) as * const _, // b
463- ldb, // ldb
464- gemm_scalar_cast!( $ty, beta) , // beta
465- c. ptr. as_ptr( ) as * mut _, // c
466- ldc, // ldc
467- ) ;
421+ macro_rules! gemm {
422+ ( $ty: tt, $gemm: ident) => {
423+ if same_type:: <A , $ty>( ) {
424+ // gemm is C ← αA^Op B^Op + βC
425+ // Where Op is notrans/trans/conjtrans
426+ unsafe {
427+ blas_sys:: $gemm(
428+ cblas_layout,
429+ a_trans,
430+ b_trans,
431+ m as blas_index, // m, rows of Op(a)
432+ n as blas_index, // n, cols of Op(b)
433+ k as blas_index, // k, cols of Op(a)
434+ gemm_scalar_cast!( $ty, alpha) , // alpha
435+ a. ptr. as_ptr( ) as * const _, // a
436+ lda, // lda
437+ b. ptr. as_ptr( ) as * const _, // b
438+ ldb, // ldb
439+ gemm_scalar_cast!( $ty, beta) , // beta
440+ c. ptr. as_ptr( ) as * mut _, // c
441+ ldc, // ldc
442+ ) ;
443+ }
444+ return ;
468445 }
469- return ;
470- }
471- } ;
472- }
446+ } ;
447+ }
473448
474- gemm ! ( f32 , cblas_sgemm) ;
475- gemm ! ( f64 , cblas_dgemm) ;
476- gemm ! ( c32, cblas_cgemm) ;
477- gemm ! ( c64, cblas_zgemm) ;
449+ gemm ! ( f32 , cblas_sgemm) ;
450+ gemm ! ( f64 , cblas_dgemm) ;
451+ gemm ! ( c32, cblas_cgemm) ;
452+ gemm ! ( c64, cblas_zgemm) ;
478453
479- break ' blas_block;
454+ unreachable ! ( ) // we checked above that A is one of f32, f64, c32, c64
455+ }
480456 }
481457 mat_mul_general ( alpha, a, b, beta, c)
482458}
0 commit comments