@@ -757,6 +757,33 @@ where
757757 }
758758}
759759
760+ /// Attempt to merge axes if possible, starting from the back
761+ ///
762+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
763+ /// to merge all axes one by one into Axis(3); when/if this fails,
764+ /// it attempts to merge the rest of the axes together into the next
765+ /// axis in line, for example a result could be:
766+ ///
767+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
768+ /// mean axes were merged.
769+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
770+ where
771+ D : Dimension ,
772+ {
773+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
774+ match dim. ndim ( ) {
775+ 0 | 1 => { }
776+ n => {
777+ let mut last = n - 1 ;
778+ for i in ( 0 ..last) . rev ( ) {
779+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
780+ last = i;
781+ }
782+ }
783+ }
784+ }
785+ }
786+
760787/// Move the axis which has the smallest absolute stride and a length
761788/// greater than one to be the last axis.
762789pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -821,12 +848,40 @@ where
821848 * strides = new_strides;
822849}
823850
851+
852+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
853+ /// stride
854+ ///
855+ /// The axes are sorted according to the .abs() of their stride.
856+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
857+ where
858+ D : Dimension ,
859+ {
860+ debug_assert ! ( dim. ndim( ) > 1 ) ;
861+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
862+ // bubble sort axes
863+ let mut changed = true ;
864+ while changed {
865+ changed = false ;
866+ for i in 0 ..dim. ndim ( ) - 1 {
867+ // make sure higher stride axes sort before.
868+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
869+ changed = true ;
870+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
871+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
872+ }
873+ }
874+ }
875+ }
876+
877+
824878#[ cfg( test) ]
825879mod test {
826880 use super :: {
827881 arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
828882 max_abs_offset_check_overflow, slice_min_max, slices_intersect,
829883 solve_linear_diophantine_eq, IntoDimension , squeeze,
884+ merge_axes_from_the_back,
830885 } ;
831886 use crate :: error:: { from_kind, ErrorKind } ;
832887 use crate :: slice:: Slice ;
@@ -1191,4 +1246,26 @@ mod test {
11911246 assert_eq ! ( d, dans) ;
11921247 assert_eq ! ( s, sans) ;
11931248 }
1249+
1250+ #[ test]
1251+ fn test_merge_axes_from_the_back ( ) {
1252+ let dyndim = Dim :: < & [ usize ] > ;
1253+
1254+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1255+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1256+ merge_axes_from_the_back ( & mut d, & mut s) ;
1257+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1258+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1259+
1260+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1261+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1262+ merge_axes_from_the_back ( & mut d, & mut s) ;
1263+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1264+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1265+ let mut d = d. into_dyn ( ) ;
1266+ let mut s = s. into_dyn ( ) ;
1267+ squeeze ( & mut d, & mut s) ;
1268+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1269+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1270+ }
11941271}
0 commit comments