11use ndarray:: prelude:: * ;
22use ndarray:: { Data , RemoveAxis , Zip } ;
33
4+ use rawpointer:: PointerExt ;
5+
46use std:: cmp:: Ordering ;
57use std:: ptr:: copy_nonoverlapping;
68
9799 where
98100 D : RemoveAxis ,
99101 {
100- let axis = axis;
101102 let axis_len = self . len_of ( axis) ;
103+ let axis_stride = self . stride_of ( axis) ;
102104 assert_eq ! ( axis_len, perm. indices. len( ) ) ;
103105 debug_assert ! ( perm. correct( ) ) ;
104106
@@ -112,26 +114,48 @@ where
112114 // logically move ownership of all elements from self into result
113115 // the result realizes this ownership at .assume_init() further down
114116 let mut moved_elements = 0 ;
117+
118+ // the permutation vector is used like this:
119+ //
120+ // index: 0 1 2 3 (index in result)
121+ // permut: 2 3 0 1 (index in the source)
122+ //
123+ // move source 2 -> result 0,
124+ // move source 3 -> result 1,
125+ // move source 0 -> result 2,
126+ // move source 1 -> result 3,
127+ // etc.
128+
129+ let source_0 = self . raw_view ( ) . index_axis_move ( axis, 0 ) ;
130+
115131 Zip :: from ( & perm. indices )
116132 . and ( result. axis_iter_mut ( axis) )
117133 . for_each ( |& perm_i, result_pane| {
118- // possible improvement: use unchecked indexing for `index_axis`
134+ // Use a shortcut to avoid bounds checking in `index_axis` for the source.
135+ //
136+ // It works because for any given element pointer in the array we have the
137+ // relationship:
138+ //
139+ // .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j)
140+ //
141+ // where + is pointer arithmetic on the element pointers.
142+ //
143+ // Here source_0 plus the offset is equivalent to self.index_axis(axis, perm_i)
119144 Zip :: from ( result_pane)
120- . and ( self . index_axis ( axis, perm_i) )
121- . for_each ( |to, from| {
145+ . and ( source_0. clone ( ) )
146+ . for_each ( |to, from_0| {
147+ let from = from_0. stride_offset ( axis_stride, perm_i) ;
122148 copy_nonoverlapping ( from, to. as_mut_ptr ( ) , 1 ) ;
123149 moved_elements += 1 ;
124150 } ) ;
125151 } ) ;
126152 debug_assert_eq ! ( result. len( ) , moved_elements) ;
127- // panic-critical begin: we must not panic
128- // forget moved array elements but not its vec
129- // old_storage drops empty
153+ // forget the old elements but not the allocation
130154 let mut old_storage = self . into_raw_vec ( ) ;
131155 old_storage. set_len ( 0 ) ;
132156
157+ // transfer ownership of the elements into the result
133158 result. assume_init ( )
134- // panic-critical end
135159 }
136160 }
137161}
@@ -179,31 +203,46 @@ mod tests {
179203 [ 75600.94 , 17. ] ,
180204 [ 75601.06 , 18. ] ,
181205 ] ;
206+ let answer = array ! [
207+ [ 75600.09 , 10. ] ,
208+ [ 75600.21 , 11. ] ,
209+ [ 75600.45 , 13. ] ,
210+ [ 75600.58 , 14. ] ,
211+ [ 75600.82 , 16. ] ,
212+ [ 75600.94 , 17. ] ,
213+ [ 75601.06 , 18. ] ,
214+ [ 75601.33 , 12. ] ,
215+ [ 107998.96 , 1. ] ,
216+ [ 107999.08 , 2. ] ,
217+ [ 107999.20 , 3. ] ,
218+ [ 107999.45 , 5. ] ,
219+ [ 107999.57 , 6. ] ,
220+ [ 107999.81 , 8. ] ,
221+ [ 107999.94 , 9. ] ,
222+ [ 108000.33 , 4. ] ,
223+ [ 108010.69 , 7. ] ,
224+ [ 109000.70 , 15. ] ,
225+ ] ;
226+
227+ // f layout copy of a
228+ let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
229+ af. assign ( & a) ;
230+
231+ // transposed copy of a
232+ let at = a. t ( ) . to_owned ( ) ;
182233
234+ // c layout permute
183235 let perm = a. sort_axis_by ( Axis ( 0 ) , |i, j| a[ [ i, 0 ] ] < a[ [ j, 0 ] ] ) ;
236+
184237 let b = a. permute_axis ( Axis ( 0 ) , & perm) ;
185- assert_eq ! (
186- b,
187- array![
188- [ 75600.09 , 10. ] ,
189- [ 75600.21 , 11. ] ,
190- [ 75600.45 , 13. ] ,
191- [ 75600.58 , 14. ] ,
192- [ 75600.82 , 16. ] ,
193- [ 75600.94 , 17. ] ,
194- [ 75601.06 , 18. ] ,
195- [ 75601.33 , 12. ] ,
196- [ 107998.96 , 1. ] ,
197- [ 107999.08 , 2. ] ,
198- [ 107999.20 , 3. ] ,
199- [ 107999.45 , 5. ] ,
200- [ 107999.57 , 6. ] ,
201- [ 107999.81 , 8. ] ,
202- [ 107999.94 , 9. ] ,
203- [ 108000.33 , 4. ] ,
204- [ 108010.69 , 7. ] ,
205- [ 109000.70 , 15. ] ,
206- ]
207- ) ;
238+ assert_eq ! ( b, answer) ;
239+
240+ // f layout permute
241+ let bf = af. permute_axis ( Axis ( 0 ) , & perm) ;
242+ assert_eq ! ( bf, answer) ;
243+
244+ // transposed permute
245+ let bt = at. permute_axis ( Axis ( 1 ) , & perm) ;
246+ assert_eq ! ( bt, answer. t( ) ) ;
208247 }
209248}
0 commit comments