@@ -759,6 +759,32 @@ where D: Dimension
759759 }
760760}
761761
762+ /// Attempt to merge axes if possible, starting from the back
763+ ///
764+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
765+ /// to merge all axes one by one into Axis(3); when/if this fails,
766+ /// it attempts to merge the rest of the axes together into the next
767+ /// axis in line, for example a result could be:
768+ ///
769+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
770+ /// mean axes were merged.
771+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
772+ where D : Dimension
773+ {
774+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
775+ match dim. ndim ( ) {
776+ 0 | 1 => { }
777+ n => {
778+ let mut last = n - 1 ;
779+ for i in ( 0 ..last) . rev ( ) {
780+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
781+ last = i;
782+ }
783+ }
784+ }
785+ }
786+ }
787+
762788/// Move the axis which has the smallest absolute stride and a length
763789/// greater than one to be the last axis.
764790pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -822,6 +848,30 @@ where D: Dimension
822848 * strides = new_strides;
823849}
824850
851+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
852+ /// stride
853+ ///
854+ /// The axes are sorted according to the .abs() of their stride.
855+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
856+ where D : Dimension
857+ {
858+ debug_assert ! ( dim. ndim( ) > 1 ) ;
859+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
860+ // bubble sort axes
861+ let mut changed = true ;
862+ while changed {
863+ changed = false ;
864+ for i in 0 ..dim. ndim ( ) - 1 {
865+ // make sure higher stride axes sort before.
866+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
867+ changed = true ;
868+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
869+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
870+ }
871+ }
872+ }
873+ }
874+
825875#[ cfg( test) ]
826876mod test
827877{
@@ -831,6 +881,7 @@ mod test
831881 can_index_slice_not_custom,
832882 extended_gcd,
833883 max_abs_offset_check_overflow,
884+ merge_axes_from_the_back,
834885 slice_min_max,
835886 slices_intersect,
836887 solve_linear_diophantine_eq,
@@ -1215,4 +1266,27 @@ mod test
12151266 assert_eq ! ( d, dans) ;
12161267 assert_eq ! ( s, sans) ;
12171268 }
1269+
1270+ #[ test]
1271+ fn test_merge_axes_from_the_back ( )
1272+ {
1273+ let dyndim = Dim :: < & [ usize ] > ;
1274+
1275+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1276+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1277+ merge_axes_from_the_back ( & mut d, & mut s) ;
1278+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1279+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1280+
1281+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1282+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1283+ merge_axes_from_the_back ( & mut d, & mut s) ;
1284+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1285+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1286+ let mut d = d. into_dyn ( ) ;
1287+ let mut s = s. into_dyn ( ) ;
1288+ squeeze ( & mut d, & mut s) ;
1289+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1290+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1291+ }
12181292}
0 commit comments