@@ -738,11 +738,19 @@ where
738738
739739#[ derive( Debug ) ]
740740pub struct AxisIterCore < A , D > {
741+ /// Index along the axis of the value of `.next()`, relative to the start
742+ /// of the axis.
741743 index : Ix ,
742- len : Ix ,
744+ /// (Exclusive) upper bound on `index`. Initially, this is equal to the
745+ /// length of the axis.
746+ end : Ix ,
747+ /// Stride along the axis (offset between consecutive pointers).
743748 stride : Ixs ,
749+ /// Shape of the iterator's items.
744750 inner_dim : D ,
751+ /// Strides of the iterator's items.
745752 inner_strides : D ,
753+ /// Pointer corresponding to `index == 0`.
746754 ptr : * mut A ,
747755}
748756
@@ -751,7 +759,7 @@ clone_bounds!(
751759 AxisIterCore [ A , D ] {
752760 @copy {
753761 index,
754- len ,
762+ end ,
755763 stride,
756764 ptr,
757765 }
@@ -767,54 +775,53 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
767775 Di : RemoveAxis < Smaller = D > ,
768776 S : Data < Elem = A > ,
769777 {
770- let shape = v. shape ( ) [ axis. index ( ) ] ;
771- let stride = v. strides ( ) [ axis. index ( ) ] ;
772778 AxisIterCore {
773779 index : 0 ,
774- len : shape ,
775- stride,
780+ end : v . len_of ( axis ) ,
781+ stride : v . stride_of ( axis ) ,
776782 inner_dim : v. dim . remove_axis ( axis) ,
777783 inner_strides : v. strides . remove_axis ( axis) ,
778784 ptr : v. ptr ,
779785 }
780786 }
781787
788+ #[ inline]
782789 unsafe fn offset ( & self , index : usize ) -> * mut A {
783790 debug_assert ! (
784- index <= self . len ,
785- "index={}, len ={}, stride={}" ,
791+ index < self . end ,
792+ "index={}, end ={}, stride={}" ,
786793 index,
787- self . len ,
794+ self . end ,
788795 self . stride
789796 ) ;
790797 self . ptr . offset ( index as isize * self . stride )
791798 }
792799
793- /// Split the iterator at index, yielding two disjoint iterators.
800+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
794801 ///
795- /// **Panics** if `index` is strictly greater than the iterator's length
802+ /// `index` is relative to the current state of the iterator (which is not
803+ /// necessarily the start of the axis).
804+ ///
805+ /// **Panics** if `index` is strictly greater than the iterator's remaining
806+ /// length.
796807 fn split_at ( self , index : usize ) -> ( Self , Self ) {
797- assert ! ( index <= self . len) ;
798- let right_ptr = if index != self . len {
799- unsafe { self . offset ( index) }
800- } else {
801- self . ptr
802- } ;
808+ assert ! ( index <= self . len( ) ) ;
809+ let mid = self . index + index;
803810 let left = AxisIterCore {
804- index : 0 ,
805- len : index ,
811+ index : self . index ,
812+ end : mid ,
806813 stride : self . stride ,
807814 inner_dim : self . inner_dim . clone ( ) ,
808815 inner_strides : self . inner_strides . clone ( ) ,
809816 ptr : self . ptr ,
810817 } ;
811818 let right = AxisIterCore {
812- index : 0 ,
813- len : self . len - index ,
819+ index : mid ,
820+ end : self . end ,
814821 stride : self . stride ,
815822 inner_dim : self . inner_dim ,
816823 inner_strides : self . inner_strides ,
817- ptr : right_ptr ,
824+ ptr : self . ptr ,
818825 } ;
819826 ( left, right)
820827 }
@@ -827,7 +834,7 @@ where
827834 type Item = * mut A ;
828835
829836 fn next ( & mut self ) -> Option < Self :: Item > {
830- if self . index >= self . len {
837+ if self . index >= self . end {
831838 None
832839 } else {
833840 let ptr = unsafe { self . offset ( self . index ) } ;
@@ -837,7 +844,7 @@ where
837844 }
838845
839846 fn size_hint ( & self ) -> ( usize , Option < usize > ) {
840- let len = self . len - self . index ;
847+ let len = self . len ( ) ;
841848 ( len, Some ( len) )
842849 }
843850}
@@ -847,16 +854,25 @@ where
847854 D : Dimension ,
848855{
849856 fn next_back ( & mut self ) -> Option < Self :: Item > {
850- if self . index >= self . len {
857+ if self . index >= self . end {
851858 None
852859 } else {
853- self . len -= 1 ;
854- let ptr = unsafe { self . offset ( self . len ) } ;
860+ let ptr = unsafe { self . offset ( self . end - 1 ) } ;
861+ self . end -= 1 ;
855862 Some ( ptr)
856863 }
857864 }
858865}
859866
867+ impl < A , D > ExactSizeIterator for AxisIterCore < A , D >
868+ where
869+ D : Dimension ,
870+ {
871+ fn len ( & self ) -> usize {
872+ self . end - self . index
873+ }
874+ }
875+
860876/// An iterator that traverses over an axis and
861877/// and yields each subview.
862878///
@@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
899915 }
900916 }
901917
902- /// Split the iterator at index, yielding two disjoint iterators.
918+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
903919 ///
904- /// **Panics** if `index` is strictly greater than the iterator's length
920+ /// `index` is relative to the current state of the iterator (which is not
921+ /// necessarily the start of the axis).
922+ ///
923+ /// **Panics** if `index` is strictly greater than the iterator's remaining
924+ /// length.
905925 pub fn split_at ( self , index : usize ) -> ( Self , Self ) {
906926 let ( left, right) = self . iter . split_at ( index) ;
907927 (
@@ -946,7 +966,7 @@ where
946966 D : Dimension ,
947967{
948968 fn len ( & self ) -> usize {
949- self . size_hint ( ) . 0
969+ self . iter . len ( )
950970 }
951971}
952972
@@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
9811001 }
9821002 }
9831003
984- /// Split the iterator at index, yielding two disjoint iterators.
1004+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
9851005 ///
986- /// **Panics** if `index` is strictly greater than the iterator's length
1006+ /// `index` is relative to the current state of the iterator (which is not
1007+ /// necessarily the start of the axis).
1008+ ///
1009+ /// **Panics** if `index` is strictly greater than the iterator's remaining
1010+ /// length.
9871011 pub fn split_at ( self , index : usize ) -> ( Self , Self ) {
9881012 let ( left, right) = self . iter . split_at ( index) ;
9891013 (
@@ -1028,7 +1052,7 @@ where
10281052 D : Dimension ,
10291053{
10301054 fn len ( & self ) -> usize {
1031- self . size_hint ( ) . 0
1055+ self . iter . len ( )
10321056 }
10331057}
10341058
@@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
10481072 }
10491073 #[ doc( hidden) ]
10501074 fn as_ptr ( & self ) -> Self :: Ptr {
1051- self . iter . ptr
1075+ if self . len ( ) > 0 {
1076+ // `self.iter.index` is guaranteed to be in-bounds if any of the
1077+ // iterator remains (i.e. if `self.len() > 0`).
1078+ unsafe { self . iter . offset ( self . iter . index ) }
1079+ } else {
1080+ // In this case, `self.iter.index` may be past the end, so we must
1081+ // not call `.offset()`. It's okay to return a dangling pointer
1082+ // because it will never be used in the length 0 case.
1083+ std:: ptr:: NonNull :: dangling ( ) . as_ptr ( )
1084+ }
10521085 }
10531086
10541087 fn contiguous_stride ( & self ) -> isize {
@@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
10651098 }
10661099 #[ doc( hidden) ]
10671100 unsafe fn uget_ptr ( & self , i : & Self :: Dim ) -> Self :: Ptr {
1068- self . iter . ptr . offset ( self . iter . stride * i[ 0 ] as isize )
1101+ self . iter . offset ( self . iter . index + i[ 0 ] )
10691102 }
10701103
10711104 #[ doc( hidden) ]
@@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
10961129 }
10971130 #[ doc( hidden) ]
10981131 fn as_ptr ( & self ) -> Self :: Ptr {
1099- self . iter . ptr
1132+ if self . len ( ) > 0 {
1133+ // `self.iter.index` is guaranteed to be in-bounds if any of the
1134+ // iterator remains (i.e. if `self.len() > 0`).
1135+ unsafe { self . iter . offset ( self . iter . index ) }
1136+ } else {
1137+ // In this case, `self.iter.index` may be past the end, so we must
1138+ // not call `.offset()`. It's okay to return a dangling pointer
1139+ // because it will never be used in the length 0 case.
1140+ std:: ptr:: NonNull :: dangling ( ) . as_ptr ( )
1141+ }
11001142 }
11011143
11021144 fn contiguous_stride ( & self ) -> isize {
@@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
11131155 }
11141156 #[ doc( hidden) ]
11151157 unsafe fn uget_ptr ( & self , i : & Self :: Dim ) -> Self :: Ptr {
1116- self . iter . ptr . offset ( self . iter . stride * i[ 0 ] as isize )
1158+ self . iter . offset ( self . iter . index + i[ 0 ] )
11171159 }
11181160
11191161 #[ doc( hidden) ]
@@ -1164,21 +1206,28 @@ clone_bounds!(
11641206///
11651207/// Returns an axis iterator with the correct stride to move between chunks,
11661208/// the number of chunks, and the shape of the last chunk.
1209+ ///
1210+ /// **Panics** if `size == 0`.
11671211fn chunk_iter_parts < A , D : Dimension > (
11681212 v : ArrayView < ' _ , A , D > ,
11691213 axis : Axis ,
11701214 size : usize ,
11711215) -> ( AxisIterCore < A , D > , usize , D ) {
1216+ assert_ne ! ( size, 0 , "Chunk size must be nonzero." ) ;
11721217 let axis_len = v. len_of ( axis) ;
1173- let size = if size > axis_len { axis_len } else { size } ;
11741218 let n_whole_chunks = axis_len / size;
11751219 let chunk_remainder = axis_len % size;
11761220 let iter_len = if chunk_remainder == 0 {
11771221 n_whole_chunks
11781222 } else {
11791223 n_whole_chunks + 1
11801224 } ;
1181- let stride = v. stride_of ( axis) * size as isize ;
1225+ let stride = if n_whole_chunks == 0 {
1226+ // This case avoids potential overflow when `size > axis_len`.
1227+ 0
1228+ } else {
1229+ v. stride_of ( axis) * size as isize
1230+ } ;
11821231
11831232 let axis = axis. index ( ) ;
11841233 let mut inner_dim = v. dim . clone ( ) ;
@@ -1193,7 +1242,7 @@ fn chunk_iter_parts<A, D: Dimension>(
11931242
11941243 let iter = AxisIterCore {
11951244 index : 0 ,
1196- len : iter_len,
1245+ end : iter_len,
11971246 stride,
11981247 inner_dim,
11991248 inner_strides : v. strides ,
@@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl {
12701319 D : Dimension ,
12711320 {
12721321 fn next_back( & mut self ) -> Option <Self :: Item > {
1273- let is_uneven = self . iter. len > self . n_whole_chunks;
1322+ let is_uneven = self . iter. end > self . n_whole_chunks;
12741323 let res = self . iter. next_back( ) ;
12751324 self . get_subview( res, is_uneven)
12761325 }
0 commit comments