@@ -261,6 +261,7 @@ impl<A, D> Array<A, D>
261261 return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
262262 }
263263
264+ let current_axis_len = self . len_of ( axis) ;
264265 let remaining_shape = self . raw_dim ( ) . remove_axis ( axis) ;
265266 let array_rem_shape = array. raw_dim ( ) . remove_axis ( axis) ;
266267
@@ -280,22 +281,46 @@ impl<A, D> Array<A, D>
280281
281282 let self_is_empty = self . is_empty ( ) ;
282283
283- // array must be empty or have `axis` as the outermost (longest stride)
284- // axis
285- if !( self_is_empty ||
286- self . axes ( ) . max_by_key ( |ax| ax. stride ) . map ( |ax| ax. axis ) == Some ( axis) )
287- {
288- return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleLayout ) ) ;
284+ // array must be empty or have `axis` as the outermost (longest stride) axis
285+ if !self_is_empty && current_axis_len > 1 {
286+ // `axis` must be max stride axis or equal to its stride
287+ let max_stride_axis = self . axes ( ) . max_by_key ( |ax| ax. stride ) . unwrap ( ) ;
288+ if max_stride_axis. axis != axis && max_stride_axis. stride > self . stride_of ( axis) {
289+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleLayout ) ) ;
290+ }
289291 }
290292
291293 // array must be be "full" (have no exterior holes)
292294 if self . len ( ) != self . data . len ( ) {
293295 return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleLayout ) ) ;
294296 }
297+
295298 let strides = if self_is_empty {
296- // recompute strides - if the array was previously empty, it could have
297- // zeros in strides.
298- res_dim. default_strides ( )
299+ // recompute strides - if the array was previously empty, it could have zeros in
300+ // strides.
301+ // The new order is based on c/f-contig but must have `axis` as outermost axis.
302+ if axis == Axis ( self . ndim ( ) - 1 ) {
303+ // prefer f-contig when appending to the last axis
304+ // Axis n - 1 is outermost axis
305+ res_dim. fortran_strides ( )
306+ } else {
307+ // Default with modification
308+ res_dim. slice_mut ( ) . swap ( 0 , axis. index ( ) ) ;
309+ let mut strides = res_dim. default_strides ( ) ;
310+ res_dim. slice_mut ( ) . swap ( 0 , axis. index ( ) ) ;
311+ strides. slice_mut ( ) . swap ( 0 , axis. index ( ) ) ;
312+ strides
313+ }
314+ } else if current_axis_len == 1 {
315+ // This is the outermost/longest stride axis; so we find the max across the other axes
316+ let new_stride = self . axes ( ) . fold ( 1 , |acc, ax| {
317+ if ax. axis == axis { acc } else {
318+ Ord :: max ( acc, ax. len as isize * ax. stride )
319+ }
320+ } ) ;
321+ let mut strides = self . strides . clone ( ) ;
322+ strides[ axis. index ( ) ] = new_stride as usize ;
323+ strides
299324 } else {
300325 self . strides . clone ( )
301326 } ;
@@ -383,7 +408,8 @@ where
383408 return ;
384409 }
385410 sort_axes_impl ( & mut a. dim , & mut a. strides , & mut b. dim , & mut b. strides ) ;
386- debug_assert ! ( a. is_standard_layout( ) ) ;
411+ debug_assert ! ( a. is_standard_layout( ) , "not std layout dim: {:?}, strides: {:?}" ,
412+ a. shape( ) , a. strides( ) ) ;
387413}
388414
389415fn sort_axes_impl < D > ( adim : & mut D , astrides : & mut D , bdim : & mut D , bstrides : & mut D )
0 commit comments