@@ -24,13 +24,11 @@ use num_complex::{Complex32 as c32, Complex64 as c64};
2424
2525#[ cfg( feature = "blas" ) ]
2626use libc:: c_int;
27- #[ cfg( feature = "blas" ) ]
28- use std:: mem:: swap;
2927
3028#[ cfg( feature = "blas" ) ]
3129use cblas_sys as blas_sys;
3230#[ cfg( feature = "blas" ) ]
33- use cblas_sys:: { CblasNoTrans , CblasRowMajor , CblasTrans , CBLAS_LAYOUT } ;
31+ use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT } ;
3432
3533/// len of vector before we use blas
3634#[ cfg( feature = "blas" ) ]
@@ -377,93 +375,65 @@ use self::mat_mul_general as mat_mul_impl;
377375#[ cfg( feature = "blas" ) ]
378376fn mat_mul_impl < A > (
379377 alpha : A ,
380- lhs : & ArrayView2 < ' _ , A > ,
381- rhs : & ArrayView2 < ' _ , A > ,
378+ a : & ArrayView2 < ' _ , A > ,
379+ b : & ArrayView2 < ' _ , A > ,
382380 beta : A ,
383381 c : & mut ArrayViewMut2 < ' _ , A > ,
384382) where
385383 A : LinalgScalar ,
386384{
387385 // size cutoff for using BLAS
388386 let cut = GEMM_BLAS_CUTOFF ;
389- let ( ( mut m, k) , ( k2, mut n) ) = ( lhs . dim ( ) , rhs . dim ( ) ) ;
387+ let ( ( m, k) , ( k2, n) ) = ( a . dim ( ) , b . dim ( ) ) ;
390388 debug_assert_eq ! ( k, k2) ;
391389 if !( m > cut || n > cut || k > cut)
392390 || !( same_type :: < A , f32 > ( )
393391 || same_type :: < A , f64 > ( )
394392 || same_type :: < A , c32 > ( )
395393 || same_type :: < A , c64 > ( ) )
396394 {
397- return mat_mul_general ( alpha, lhs , rhs , beta, c) ;
395+ return mat_mul_general ( alpha, a , b , beta, c) ;
398396 }
399397
400398 #[ allow( clippy:: never_loop) ] // MSRV Rust 1.64 does not have break from block
401399 ' blas_block: loop {
402- let mut a = lhs. view ( ) ;
403- let mut b = rhs. view ( ) ;
404- let mut c = c. view_mut ( ) ;
405-
406- let c_layout = get_blas_compatible_layout ( & c) ;
407- let c_layout_is_c = matches ! ( c_layout, Some ( MemoryOrder :: C ) ) ;
408- let c_layout_is_f = matches ! ( c_layout, Some ( MemoryOrder :: F ) ) ;
409-
410400 // Compute A B -> C
411- // we require for BLAS compatibility that:
412- // A, B are contiguous (stride=1) in their fastest dimension.
413- // C is c-contiguous in one dimension (stride=1 in Axis(1))
401+ // We require for BLAS compatibility that:
402+ // A, B, C are contiguous (stride=1) in their fastest dimension,
403+ // but it can be either first or second axis (either rowmajor/"c" or colmajor/"f").
414404 //
415- // If C is f-contiguous, use transpose equivalency
416- // to translate to the C-contiguous case:
417- // A^t B^t = C^t => B A = C
418-
419- let ( a_layout, b_layout) =
420- match ( get_blas_compatible_layout ( & a) , get_blas_compatible_layout ( & b) ) {
421- ( Some ( a_layout) , Some ( b_layout) ) if c_layout_is_c => {
422- // normal case
423- ( a_layout, b_layout)
405+ // The "normal case" is CblasRowMajor for cblas.
406+ // Select CblasRowMajor or CblasColMajor to match C's memory order.
407+ //
408+ // Apply transpose to A, B as needed if they differ from the normal case.
409+ // If C is CblasColMajor then transpose both A, B (again!)
410+
411+ let ( a_layout, a_axis, b_layout, b_axis, c_layout) =
412+ match ( get_blas_compatible_layout ( a) ,
413+ get_blas_compatible_layout ( b) ,
414+ get_blas_compatible_layout ( c) )
415+ {
416+ ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: C ) ) => {
417+ ( a_layout, a_layout. lead_axis ( ) ,
418+ b_layout, b_layout. lead_axis ( ) , c_layout)
424419 } ,
425- ( Some ( a_layout) , Some ( b_layout) ) if c_layout_is_f => {
426- // Transpose equivalency
427- // A^t B^t = C^t => B A = C
428- //
429- // A^t becomes the new B
430- // B^t becomes the new A
431- let a_t = a. reversed_axes ( ) ;
432- a = b. reversed_axes ( ) ;
433- b = a_t;
434- c = c. reversed_axes ( ) ;
435- // Assign (n, k, m) -> (m, k, n) effectively
436- swap ( & mut m, & mut n) ;
437-
438- // Continue using the already computed memory layouts
439- ( b_layout. opposite ( ) , a_layout. opposite ( ) )
420+ ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: F ) ) => {
421+ // CblasColMajor is the "other case"
422+ // Mark a, b as having layouts opposite of what they were detected as, which
423+ // ends up with the correct transpose setting w.r.t col major
424+ ( a_layout. opposite ( ) , a_layout. lead_axis ( ) ,
425+ b_layout. opposite ( ) , b_layout. lead_axis ( ) , c_layout)
440426 } ,
441- _otherwise => {
442- break ' blas_block;
443- }
427+ _ => break ' blas_block,
444428 } ;
445429
446- let a_trans;
447- let b_trans;
448- let lda; // Stride of a
449- let ldb; // Stride of b
430+ let a_trans = a_layout. to_cblas_transpose ( ) ;
431+ let lda = blas_stride ( & a, a_axis) ;
450432
451- if let MemoryOrder :: C = a_layout {
452- lda = blas_stride ( & a, 0 ) ;
453- a_trans = CblasNoTrans ;
454- } else {
455- lda = blas_stride ( & a, 1 ) ;
456- a_trans = CblasTrans ;
457- }
433+ let b_trans = b_layout. to_cblas_transpose ( ) ;
434+ let ldb = blas_stride ( & b, b_axis) ;
458435
459- if let MemoryOrder :: C = b_layout {
460- ldb = blas_stride ( & b, 0 ) ;
461- b_trans = CblasNoTrans ;
462- } else {
463- ldb = blas_stride ( & b, 1 ) ;
464- b_trans = CblasTrans ;
465- }
466- let ldc = blas_stride ( & c, 0 ) ;
436+ let ldc = blas_stride ( & c, c_layout. lead_axis ( ) ) ;
467437
468438 macro_rules! gemm_scalar_cast {
469439 ( f32 , $var: ident) => {
@@ -487,7 +457,7 @@ fn mat_mul_impl<A>(
487457 // Where Op is notrans/trans/conjtrans
488458 unsafe {
489459 blas_sys:: $gemm(
490- CblasRowMajor ,
460+ c_layout . to_cblas_layout ( ) ,
491461 a_trans,
492462 b_trans,
493463 m as blas_index, // m, rows of Op(a)
@@ -507,14 +477,15 @@ fn mat_mul_impl<A>(
507477 }
508478 } ;
509479 }
480+
510481 gemm ! ( f32 , cblas_sgemm) ;
511482 gemm ! ( f64 , cblas_dgemm) ;
512-
513483 gemm ! ( c32, cblas_cgemm) ;
514484 gemm ! ( c64, cblas_zgemm) ;
485+
515486 break ' blas_block;
516487 }
517- mat_mul_general ( alpha, lhs , rhs , beta, c)
488+ mat_mul_general ( alpha, a , b , beta, c)
518489}
519490
520491/// C ← α A B + β C
@@ -873,13 +844,41 @@ enum MemoryOrder
#[cfg(feature = "blas")]
impl MemoryOrder
{
    /// Axis whose stride is the leading stride, i.e. the axis that is
    /// not the contiguous one: `0` for `C` order, `1` for `F` order.
    #[inline]
    fn lead_axis(self) -> usize
    {
        match self {
            Self::C => 0,
            Self::F => 1,
        }
    }

    /// Return the other memory order (`C` becomes `F`, `F` becomes `C`).
    #[inline]
    fn opposite(self) -> Self
    {
        match self {
            Self::C => Self::F,
            Self::F => Self::C,
        }
    }

    /// Map this memory order to a cblas transpose flag:
    /// `C` gives `CblasNoTrans`, `F` gives `CblasTrans`.
    fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE
    {
        match self {
            Self::C => CblasNoTrans,
            Self::F => CblasTrans,
        }
    }

    /// Map this memory order to a cblas layout:
    /// `C` gives `CblasRowMajor`, `F` gives `CblasColMajor`.
    fn to_cblas_layout(self) -> CBLAS_LAYOUT
    {
        match self {
            Self::C => CBLAS_LAYOUT::CblasRowMajor,
            Self::F => CBLAS_LAYOUT::CblasColMajor,
        }
    }
}
884883
885884#[ cfg( feature = "blas" ) ]
0 commit comments