5959//! });
6060//! ```
6161//!
62- //! The second example shows that while non-overlapping views are supported,
63- //! interleaved views which do not touch are currently not supported
64- //! due to over-approximating which borrows are in conflict.
62+ //! The second example shows that non-overlapping and interleaved views are also supported.
6563//!
6664//! ```rust
67- //! # use std::panic::{catch_unwind, AssertUnwindSafe};
68- //! #
6965//! use numpy::PyArray1;
7066//! use pyo3::{types::IntoPyDict, Python};
7167//!
7874//! let view3 = py.eval("array[::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
7975//! let view4 = py.eval("array[1::2]", None, Some(locals)).unwrap().downcast::<PyArray1<f64>>().unwrap();
8076//!
81- //! let _view1 = view1.readwrite();
82- //! let _view2 = view2.readwrite();
77+ //! {
78+ //! let _view1 = view1.readwrite();
79+ //! let _view2 = view2.readwrite();
80+ //! }
8381//!
84- //! // Will fail at runtime even though `view3` and `view4`
85- //! // interleave as they are based on the same array.
86- //! let res = catch_unwind(AssertUnwindSafe(|| {
82+ //! {
8783//! let _view3 = view3.readwrite();
8884//! let _view4 = view4.readwrite();
89- //! }));
90- //! assert!(res.is_err());
85+ //! }
9186//! });
9287//! ```
9388//!
125120//!
126121//! # Limitations
127122//!
128- //! Note that the current implementation of this is an over-approximation: It will consider overlapping borrows
123+ //! Note that the current implementation of this is an over-approximation: It will consider borrows
129124//! potentially conflicting if the initial arrays have the same object at the end of their [base object chain][base].
130- //! For example, creating two views of the same underlying array by slicing can yield potentially conflicting borrows
131- //! even if the slice indices are chosen so that the two views do not actually share any elements by interleaving along one of its axes.
125+ //! Then, multiple conditions which are sufficient but not necessary to show the absence of conflicts are checked,
126+ //! but there are cases which they do not handle, for example slicing an array with a step size
127+ //! that does not divide its dimension along that axis. In these situations, borrows are rejected even though the arrays
128+ //! do not actually share any elements.
132129//!
133130//! This does limit the set of programs that can be written using safe Rust in way similar to rustc itself
134131//! which ensures that all accepted programs are memory safe but does not necessarily accept all memory safe programs.
135- //! The plan is to refine this checking to correctly handle more involved cases like interleaved views
136- //! into the same array and until then the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch.
132+ //! In the future, more involved cases like the example from above may be handled and until then,
133+ //! the unsafe method [`PyArray::as_array_mut`] can be used as an escape hatch.
137134//!
138135//! [base]: https://numpy.org/doc/stable/reference/c-api/types-and-structures.html#c.NPY_AO.base
139136#![ deny( missing_docs) ]
@@ -143,6 +140,7 @@ use std::collections::hash_map::{Entry, HashMap};
143140use std:: ops:: { Deref , Range } ;
144141
145142use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
143+ use num_integer:: gcd;
146144use pyo3:: { FromPyObject , PyAny , PyResult } ;
147145
148146use crate :: array:: PyArray ;
@@ -155,9 +153,28 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
155153#[ derive( PartialEq , Eq , Hash ) ]
156154struct BorrowKey {
157155 range : Range < usize > ,
156+ data_ptr : usize ,
157+ gcd_strides : isize ,
158158}
159159
160160impl BorrowKey {
161+ fn from_array < T , D > ( array : & PyArray < T , D > ) -> Self
162+ where
163+ T : Element ,
164+ D : Dimension ,
165+ {
166+ let range = data_range ( array) ;
167+
168+ let data_ptr = array. data ( ) as usize ;
169+ let gcd_strides = reduce ( array. strides ( ) . iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 ) ;
170+
171+ Self {
172+ range,
173+ data_ptr,
174+ gcd_strides,
175+ }
176+ }
177+
161178 fn conflicts ( & self , other : & Self ) -> bool {
162179 debug_assert ! ( self . range. start <= self . range. end) ;
163180 debug_assert ! ( other. range. start <= other. range. end) ;
@@ -166,6 +183,21 @@ impl BorrowKey {
166183 return false ;
167184 }
168185
186+ // The Diophantine equation which describes whether any integers can combine the data pointers and strides of the two arrays s.t.
187+ // they yield the same element has a solution if and only if the GCD of all strides divides the difference of the data pointers.
188+ //
189+ // That solution could be out of bounds which mean that this is still an over-approximation.
190+ // It appears sufficient to handle typical cases like the color channels of an image,
191+ // but fails when slicing an array with a step size that does not divide the dimension along that axis.
192+ //
193+ // https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
194+ let ptr_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
195+ let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
196+
197+ if ptr_diff % gcd_strides != 0 {
198+ return false ;
199+ }
200+
169201 true
170202 }
171203}
@@ -192,10 +224,7 @@ impl BorrowFlags {
192224 D : Dimension ,
193225 {
194226 let address = base_address ( array) ;
195-
196- let key = BorrowKey {
197- range : data_range ( array) ,
198- } ;
227+ let key = BorrowKey :: from_array ( array) ;
199228
200229 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
201230 // and we are not calling into user code which might re-enter this function.
@@ -242,10 +271,7 @@ impl BorrowFlags {
242271 D : Dimension ,
243272 {
244273 let address = base_address ( array) ;
245-
246- let key = BorrowKey {
247- range : data_range ( array) ,
248- } ;
274+ let key = BorrowKey :: from_array ( array) ;
249275
250276 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
251277 // and we are not calling into user code which might re-enter this function.
@@ -272,10 +298,7 @@ impl BorrowFlags {
272298 D : Dimension ,
273299 {
274300 let address = base_address ( array) ;
275-
276- let key = BorrowKey {
277- range : data_range ( array) ,
278- } ;
301+ let key = BorrowKey :: from_array ( array) ;
279302
280303 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
281304 // and we are not calling into user code which might re-enter this function.
@@ -320,10 +343,7 @@ impl BorrowFlags {
320343 D : Dimension ,
321344 {
322345 let address = base_address ( array) ;
323-
324- let key = BorrowKey {
325- range : data_range ( array) ,
326- } ;
346+ let key = BorrowKey :: from_array ( array) ;
327347
328348 // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
329349 // and we are not calling into user code which might re-enter this function.
@@ -628,6 +648,25 @@ where
628648 Range { start, end }
629649}
630650
651+ // FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
652+ fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
653+ if lhs >= rhs {
654+ lhs - rhs
655+ } else {
656+ rhs - lhs
657+ }
658+ }
659+
660+ // FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
661+ fn reduce < I , F > ( mut iter : I , f : F ) -> Option < I :: Item >
662+ where
663+ I : Iterator ,
664+ F : FnMut ( I :: Item , I :: Item ) -> I :: Item ,
665+ {
666+ let first = iter. next ( ) ?;
667+ Some ( iter. fold ( first, f) )
668+ }
669+
631670#[ cfg( test) ]
632671mod tests {
633672 use super :: * ;
@@ -650,7 +689,7 @@ mod tests {
650689 assert_eq ! ( base_address, array as * const _ as usize ) ;
651690
652691 let data_range = data_range ( array) ;
653- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
692+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
654693 assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
655694 } ) ;
656695 }
@@ -668,7 +707,7 @@ mod tests {
668707 assert_eq ! ( base_address, base as usize ) ;
669708
670709 let data_range = data_range ( array) ;
671- assert_eq ! ( data_range. start, unsafe { array. data( ) } as usize ) ;
710+ assert_eq ! ( data_range. start, array. data( ) as usize ) ;
672711 assert_eq ! ( data_range. end, unsafe { array. data( ) . add( 15 ) } as usize ) ;
673712 } ) ;
674713 }
@@ -694,7 +733,7 @@ mod tests {
694733 assert_eq ! ( base_address, base as usize ) ;
695734
696735 let data_range = data_range ( view) ;
697- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
736+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
698737 assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
699738 } ) ;
700739 }
@@ -724,7 +763,7 @@ mod tests {
724763 assert_eq ! ( base_address, base as usize ) ;
725764
726765 let data_range = data_range ( view) ;
727- assert_eq ! ( data_range. start, unsafe { view. data( ) } as usize ) ;
766+ assert_eq ! ( data_range. start, view. data( ) as usize ) ;
728767 assert_eq ! ( data_range. end, unsafe { view. data( ) . add( 12 ) } as usize ) ;
729768 } ) ;
730769 }
@@ -763,7 +802,7 @@ mod tests {
763802 assert_eq ! ( base_address, base as usize ) ;
764803
765804 let data_range = data_range ( view2) ;
766- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
805+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
767806 assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
768807 } ) ;
769808 }
@@ -806,7 +845,7 @@ mod tests {
806845 assert_eq ! ( base_address, base as usize ) ;
807846
808847 let data_range = data_range ( view2) ;
809- assert_eq ! ( data_range. start, unsafe { view2. data( ) } as usize ) ;
848+ assert_eq ! ( data_range. start, view2. data( ) as usize ) ;
810849 assert_eq ! ( data_range. end, unsafe { view2. data( ) . add( 6 ) } as usize ) ;
811850 } ) ;
812851 }
@@ -836,4 +875,63 @@ mod tests {
836875 assert_eq ! ( data_range. end, unsafe { view. data( ) . offset( 6 ) } as usize ) ;
837876 } ) ;
838877 }
878+
879+ #[ test]
880+ fn view_with_non_dividing_strides ( ) {
881+ Python :: with_gil ( |py| {
882+ let array = PyArray :: < f64 , _ > :: zeros ( py, ( 10 , 10 ) , false ) ;
883+ let locals = [ ( "array" , array) ] . into_py_dict ( py) ;
884+
885+ let view1 = py
886+ . eval ( "array[:,::3]" , None , Some ( locals) )
887+ . unwrap ( )
888+ . downcast :: < PyArray2 < f64 > > ( )
889+ . unwrap ( ) ;
890+
891+ let key1 = BorrowKey :: from_array ( view1) ;
892+
893+ assert_eq ! ( view1. strides( ) , & [ 80 , 24 ] ) ;
894+ assert_eq ! ( key1. gcd_strides, 8 ) ;
895+
896+ let view2 = py
897+ . eval ( "array[:,1::3]" , None , Some ( locals) )
898+ . unwrap ( )
899+ . downcast :: < PyArray2 < f64 > > ( )
900+ . unwrap ( ) ;
901+
902+ let key2 = BorrowKey :: from_array ( view2) ;
903+
904+ assert_eq ! ( view2. strides( ) , & [ 80 , 24 ] ) ;
905+ assert_eq ! ( key2. gcd_strides, 8 ) ;
906+
907+ let view3 = py
908+ . eval ( "array[:,::2]" , None , Some ( locals) )
909+ . unwrap ( )
910+ . downcast :: < PyArray2 < f64 > > ( )
911+ . unwrap ( ) ;
912+
913+ let key3 = BorrowKey :: from_array ( view3) ;
914+
915+ assert_eq ! ( view3. strides( ) , & [ 80 , 16 ] ) ;
916+ assert_eq ! ( key3. gcd_strides, 16 ) ;
917+
918+ let view4 = py
919+ . eval ( "array[:,1::2]" , None , Some ( locals) )
920+ . unwrap ( )
921+ . downcast :: < PyArray2 < f64 > > ( )
922+ . unwrap ( ) ;
923+
924+ let key4 = BorrowKey :: from_array ( view4) ;
925+
926+ assert_eq ! ( view4. strides( ) , & [ 80 , 16 ] ) ;
927+ assert_eq ! ( key4. gcd_strides, 16 ) ;
928+
929+ assert ! ( !key3. conflicts( & key4) ) ;
930+ assert ! ( key1. conflicts( & key3) ) ;
931+ assert ! ( key2. conflicts( & key4) ) ;
932+
933+ // This is a false conflict where all aliasing indices like (0,7) and (2,0) are out of bounds.
934+ assert ! ( key1. conflicts( & key2) ) ;
935+ } ) ;
936+ }
839937}
0 commit comments