@@ -863,6 +863,7 @@ where
863863
864864#[ cfg( feature = "blas" ) ]
865865#[ derive( Copy , Clone ) ]
866+ #[ cfg_attr( test, derive( PartialEq , Eq , Debug ) ) ]
866867enum MemoryOrder
867868{
868869 C ,
@@ -887,24 +888,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
887888 let ( m, n) = dim. into_pattern ( ) ;
888889 let s0 = stride[ 0 ] as isize ;
889890 let s1 = stride[ 1 ] as isize ;
890- let ( inner_stride, outer_dim) = match order {
891- MemoryOrder :: C => ( s1, n) ,
892- MemoryOrder :: F => ( s0, m) ,
891+ let ( inner_stride, outer_stride , inner_dim , outer_dim) = match order {
892+ MemoryOrder :: C => ( s1, s0 , m , n) ,
893+ MemoryOrder :: F => ( s0, s1 , n , m) ,
893894 } ;
895+
894896 if !( inner_stride == 1 || outer_dim == 1 ) {
895897 return false ;
896898 }
899+
897900 if s0 < 1 || s1 < 1 {
898901 return false ;
899902 }
903+
900904 if ( s0 > blas_index:: MAX as isize || s0 < blas_index:: MIN as isize )
901905 || ( s1 > blas_index:: MAX as isize || s1 < blas_index:: MIN as isize )
902906 {
903907 return false ;
904908 }
909+
910+ // leading stride must >= the dimension (no broadcasting/aliasing)
911+ if inner_dim > 1 && ( outer_stride as usize ) < outer_dim {
912+ return false ;
913+ }
914+
905915 if m > blas_index:: MAX as usize || n > blas_index:: MAX as usize {
906916 return false ;
907917 }
918+
908919 true
909920}
910921
@@ -1068,8 +1079,26 @@ mod blas_tests
10681079 }
10691080
10701081 #[ test]
1071- fn test ( )
1082+ fn blas_too_short_stride ( )
10721083 {
1073- //WIP test that stride is larger than other dimension
1084+ // leading stride must be longer than the other dimension
1085+ // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.
1086+
1087+ const N : usize = 5 ;
1088+ const MAXSTRIDE : usize = N + 2 ;
1089+ let mut data = [ 0 ; MAXSTRIDE * N ] ;
1090+ let mut iter = 0 ..data. len ( ) ;
1091+ data. fill_with ( || iter. next ( ) . unwrap ( ) ) ;
1092+
1093+ for stride in 1 ..=MAXSTRIDE {
1094+ let m = ArrayView :: from_shape ( ( N , N ) . strides ( ( stride, 1 ) ) , & data) . unwrap ( ) ;
1095+ eprintln ! ( "{:?}" , m) ;
1096+
1097+ if stride < N {
1098+ assert_eq ! ( get_blas_compatible_layout( & m) , None ) ;
1099+ } else {
1100+ assert_eq ! ( get_blas_compatible_layout( & m) , Some ( MemoryOrder :: C ) ) ;
1101+ }
1102+ }
10741103 }
10751104}
0 commit comments