@@ -28,7 +28,7 @@ use libc::c_int;
2828#[ cfg( feature = "blas" ) ]
2929use cblas_sys as blas_sys;
3030#[ cfg( feature = "blas" ) ]
31- use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT } ;
31+ use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT , CBLAS_TRANSPOSE } ;
3232
3333/// len of vector before we use blas
3434#[ cfg( feature = "blas" ) ]
@@ -400,40 +400,33 @@ fn mat_mul_impl<A>(
400400 // Compute A B -> C
401401 // We require for BLAS compatibility that:
402402 // A, B, C are contiguous (stride=1) in their fastest dimension,
403- // but it can be either first or second axis (either rowmajor /"c" or colmajor /"f") .
403+ // but they can be either row major /"c" or col major /"f".
404404 //
405405 // The "normal case" is CblasRowMajor for cblas.
406- // Select CblasRowMajor, CblasColMajor to fit C's memory order.
406+ // Select CblasRowMajor / CblasColMajor to fit C's memory order.
407407 //
408- // Apply transpose to A, B as needed if they differ from the normal case.
408+ // Apply transpose to A, B as needed if they differ from the row major case.
409409 // If C is CblasColMajor then transpose both A, B (again!)
410410
411- let ( a_layout, a_axis, b_layout, b_axis, c_layout) =
412- match ( get_blas_compatible_layout ( a) ,
413- get_blas_compatible_layout ( b) ,
414- get_blas_compatible_layout ( c) )
411+ let ( a_layout, b_layout, c_layout) =
412+ if let ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout) ) =
413+ ( get_blas_compatible_layout ( a) ,
414+ get_blas_compatible_layout ( b) ,
415+ get_blas_compatible_layout ( c) )
415416 {
416- ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: C ) ) => {
417- ( a_layout, a_layout. lead_axis ( ) ,
418- b_layout, b_layout. lead_axis ( ) , c_layout)
419- } ,
420- ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: F ) ) => {
421- // CblasColMajor is the "other case"
422- // Mark a, b as having layouts opposite of what they were detected as, which
423- // ends up with the correct transpose setting w.r.t col major
424- ( a_layout. opposite ( ) , a_layout. lead_axis ( ) ,
425- b_layout. opposite ( ) , b_layout. lead_axis ( ) , c_layout)
426- } ,
427- _ => break ' blas_block,
417+ ( a_layout, b_layout, c_layout)
418+ } else {
419+ break ' blas_block;
428420 } ;
429421
430- let a_trans = a_layout. to_cblas_transpose ( ) ;
431- let lda = blas_stride ( & a, a_axis) ;
422+ let cblas_layout = c_layout. to_cblas_layout ( ) ;
423+ let a_trans = a_layout. to_cblas_transpose_for ( cblas_layout) ;
424+ let lda = blas_stride ( & a, a_layout) ;
432425
433- let b_trans = b_layout. to_cblas_transpose ( ) ;
434- let ldb = blas_stride ( & b, b_axis ) ;
426+ let b_trans = b_layout. to_cblas_transpose_for ( cblas_layout ) ;
427+ let ldb = blas_stride ( & b, b_layout ) ;
435428
436- let ldc = blas_stride ( & c, c_layout. lead_axis ( ) ) ;
429+ let ldc = blas_stride ( & c, c_layout) ;
437430
438431 macro_rules! gemm_scalar_cast {
439432 ( f32 , $var: ident) => {
@@ -457,7 +450,7 @@ fn mat_mul_impl<A>(
457450 // Where Op is notrans/trans/conjtrans
458451 unsafe {
459452 blas_sys:: $gemm(
460- c_layout . to_cblas_layout ( ) ,
453+ cblas_layout ,
461454 a_trans,
462455 b_trans,
463456 m as blas_index, // m, rows of Op(a)
@@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
696689 // may be arbitrary.
697690 let a_trans = CblasNoTrans ;
698691
699- let ( a_stride, cblas_layout) = match layout {
700- MemoryOrder :: C => {
701- ( a. strides( ) [ 0 ] . max( k as isize ) as blas_index,
702- CBLAS_LAYOUT :: CblasRowMajor )
703- }
704- MemoryOrder :: F => {
705- ( a. strides( ) [ 1 ] . max( m as isize ) as blas_index,
706- CBLAS_LAYOUT :: CblasColMajor )
707- }
708- } ;
692+ let a_stride = blas_stride( & a, layout) ;
693+ let cblas_layout = layout. to_cblas_layout( ) ;
709694
710695 // Low addr in memory pointers required for x, y
711696 let x_offset = offset_from_low_addr_ptr_to_logical_ptr( & x. dim, & x. strides) ;
@@ -835,61 +820,66 @@ where
835820#[ cfg( feature = "blas" ) ]
836821#[ derive( Copy , Clone ) ]
837822#[ cfg_attr( test, derive( PartialEq , Eq , Debug ) ) ]
838- enum MemoryOrder
823+ enum BlasOrder
839824{
840825 C ,
841826 F ,
842827}
843828
844829#[ cfg( feature = "blas" ) ]
845- impl MemoryOrder
830+ impl BlasOrder
846831{
847- #[ inline]
848- /// Axis of leading stride (opposite of contiguous axis)
849- fn lead_axis ( self ) -> usize
832+ fn transpose ( self ) -> Self
850833 {
851834 match self {
852- MemoryOrder :: C => 0 ,
853- MemoryOrder :: F => 1 ,
835+ Self :: C => Self :: F ,
836+ Self :: F => Self :: C ,
854837 }
855838 }
856839
857- /// Get opposite memory order
858840 #[ inline]
859- fn opposite ( self ) -> Self
841+ /// Axis of leading stride (opposite of contiguous axis)
842+ fn get_blas_lead_axis ( self ) -> usize
860843 {
861844 match self {
862- MemoryOrder :: C => MemoryOrder :: F ,
863- MemoryOrder :: F => MemoryOrder :: C ,
845+ Self :: C => 0 ,
846+ Self :: F => 1 ,
864847 }
865848 }
866849
867- fn to_cblas_transpose ( self ) -> cblas_sys :: CBLAS_TRANSPOSE
850+ fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
868851 {
869852 match self {
870- MemoryOrder :: C => CblasNoTrans ,
871- MemoryOrder :: F => CblasTrans ,
853+ Self :: C => CBLAS_LAYOUT :: CblasRowMajor ,
854+ Self :: F => CBLAS_LAYOUT :: CblasColMajor ,
872855 }
873856 }
874857
875- fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
858+ /// When using cblas_sgemm (etc) with C matrix using `for_layout`,
859+ /// how should this `self` matrix be transposed
860+ fn to_cblas_transpose_for ( self , for_layout : CBLAS_LAYOUT ) -> CBLAS_TRANSPOSE
876861 {
877- match self {
878- MemoryOrder :: C => CBLAS_LAYOUT :: CblasRowMajor ,
879- MemoryOrder :: F => CBLAS_LAYOUT :: CblasColMajor ,
862+ let effective_order = match for_layout {
863+ CBLAS_LAYOUT :: CblasRowMajor => self ,
864+ CBLAS_LAYOUT :: CblasColMajor => self . transpose ( ) ,
865+ } ;
866+
867+ match effective_order {
868+ Self :: C => CblasNoTrans ,
869+ Self :: F => CblasTrans ,
880870 }
881871 }
882872}
883873
884874#[ cfg( feature = "blas" ) ]
885- fn is_blas_2d ( dim : & Ix2 , stride : & Ix2 , order : MemoryOrder ) -> bool
875+ fn is_blas_2d ( dim : & Ix2 , stride : & Ix2 , order : BlasOrder ) -> bool
886876{
887877 let ( m, n) = dim. into_pattern ( ) ;
888878 let s0 = stride[ 0 ] as isize ;
889879 let s1 = stride[ 1 ] as isize ;
890880 let ( inner_stride, outer_stride, inner_dim, outer_dim) = match order {
891- MemoryOrder :: C => ( s1, s0, m, n) ,
892- MemoryOrder :: F => ( s0, s1, n, m) ,
881+ BlasOrder :: C => ( s1, s0, m, n) ,
882+ BlasOrder :: F => ( s0, s1, n, m) ,
893883 } ;
894884
895885 if !( inner_stride == 1 || outer_dim == 1 ) {
@@ -920,13 +910,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
920910
921911/// Get BLAS compatible layout if any (C or F, preferring the former)
922912#[ cfg( feature = "blas" ) ]
923- fn get_blas_compatible_layout < S > ( a : & ArrayBase < S , Ix2 > ) -> Option < MemoryOrder >
913+ fn get_blas_compatible_layout < S > ( a : & ArrayBase < S , Ix2 > ) -> Option < BlasOrder >
924914where S : Data
925915{
926- if is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: C ) {
927- Some ( MemoryOrder :: C )
928- } else if is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: F ) {
929- Some ( MemoryOrder :: F )
916+ if is_blas_2d ( & a. dim , & a. strides , BlasOrder :: C ) {
917+ Some ( BlasOrder :: C )
918+ } else if is_blas_2d ( & a. dim , & a. strides , BlasOrder :: F ) {
919+ Some ( BlasOrder :: F )
930920 } else {
931921 None
932922 }
@@ -937,10 +927,10 @@ where S: Data
937927///
938928/// Return leading stride (lda, ldb, ldc) of array
939929#[ cfg( feature = "blas" ) ]
940- fn blas_stride < S > ( a : & ArrayBase < S , Ix2 > , axis : usize ) -> blas_index
930+ fn blas_stride < S > ( a : & ArrayBase < S , Ix2 > , order : BlasOrder ) -> blas_index
941931where S : Data
942932{
943- debug_assert ! ( axis <= 1 ) ;
933+ let axis = order . get_blas_lead_axis ( ) ;
944934 let other_axis = 1 - axis;
945935 let len_this = a. shape ( ) [ axis] ;
946936 let len_other = a. shape ( ) [ other_axis] ;
@@ -968,7 +958,7 @@ where
968958 if !same_type :: < A , S :: Elem > ( ) {
969959 return false ;
970960 }
971- is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: C )
961+ is_blas_2d ( & a. dim , & a. strides , BlasOrder :: C )
972962}
973963
974964#[ cfg( test) ]
@@ -982,7 +972,7 @@ where
982972 if !same_type :: < A , S :: Elem > ( ) {
983973 return false ;
984974 }
985- is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: F )
975+ is_blas_2d ( & a. dim , & a. strides , BlasOrder :: F )
986976}
987977
988978#[ cfg( test) ]
@@ -1096,7 +1086,7 @@ mod blas_tests
10961086 if stride < N {
10971087 assert_eq ! ( get_blas_compatible_layout( & m) , None ) ;
10981088 } else {
1099- assert_eq ! ( get_blas_compatible_layout( & m) , Some ( MemoryOrder :: C ) ) ;
1089+ assert_eq ! ( get_blas_compatible_layout( & m) , Some ( BlasOrder :: C ) ) ;
11001090 }
11011091 }
11021092 }
0 commit comments