@@ -180,7 +180,7 @@ use crate::dtype::Element;
180180use crate :: error:: { BorrowError , NotContiguousError } ;
181181use crate :: npyffi:: { self , PyArrayObject , NPY_ARRAY_WRITEABLE } ;
182182
183- #[ derive( PartialEq , Eq , Hash ) ]
183+ #[ derive( Clone , Copy , PartialEq , Eq , Hash ) ]
184184struct BorrowKey {
185185 /// exclusive range of lowest and highest address covered by array
186186 range : ( usize , usize ) ,
@@ -199,7 +199,7 @@ impl BorrowKey {
199199 let range = data_range ( array) ;
200200
201201 let data_ptr = array. data ( ) as usize ;
202- let gcd_strides = reduce ( array. strides ( ) . iter ( ) . copied ( ) , gcd ) . unwrap_or ( 1 ) ;
202+ let gcd_strides = gcd_strides ( array. strides ( ) ) ;
203203
204204 Self {
205205 range,
@@ -252,16 +252,9 @@ impl BorrowFlags {
252252 ( * self . 0 . get ( ) ) . get_or_insert_with ( AHashMap :: new)
253253 }
254254
255- fn acquire < T , D > ( & self , array : & PyArray < T , D > ) -> Result < ( ) , BorrowError >
256- where
257- T : Element ,
258- D : Dimension ,
259- {
260- let address = base_address ( array) ;
261- let key = BorrowKey :: from_array ( array) ;
262-
263- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
264- // and we are not calling into user code which might re-enter this function.
255+ fn acquire ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
256+ // SAFETY: Having `_py` implies holding the GIL and
257+ // we are not calling into user code which might re-enter this function.
265258 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
266259
267260 match borrow_flags. entry ( address) {
@@ -302,16 +295,9 @@ impl BorrowFlags {
302295 Ok ( ( ) )
303296 }
304297
305- fn release < T , D > ( & self , array : & PyArray < T , D > )
306- where
307- T : Element ,
308- D : Dimension ,
309- {
310- let address = base_address ( array) ;
311- let key = BorrowKey :: from_array ( array) ;
312-
313- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
314- // and we are not calling into user code which might re-enter this function.
298+ fn release ( & self , _py : Python , address : usize , key : BorrowKey ) {
299+ // SAFETY: Having `_py` implies holding the GIL and
300+ // we are not calling into user code which might re-enter this function.
315301 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
316302
317303 let same_base_arrays = borrow_flags. get_mut ( & address) . unwrap ( ) ;
@@ -329,16 +315,9 @@ impl BorrowFlags {
329315 }
330316 }
331317
332- fn acquire_mut < T , D > ( & self , array : & PyArray < T , D > ) -> Result < ( ) , BorrowError >
333- where
334- T : Element ,
335- D : Dimension ,
336- {
337- let address = base_address ( array) ;
338- let key = BorrowKey :: from_array ( array) ;
339-
340- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
341- // and we are not calling into user code which might re-enter this function.
318+ fn acquire_mut ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
319+ // SAFETY: Having `_py` implies holding the GIL and
320+ // we are not calling into user code which might re-enter this function.
342321 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
343322
344323 match borrow_flags. entry ( address) {
@@ -373,16 +352,9 @@ impl BorrowFlags {
373352 Ok ( ( ) )
374353 }
375354
376- fn release_mut < T , D > ( & self , array : & PyArray < T , D > )
377- where
378- T : Element ,
379- D : Dimension ,
380- {
381- let address = base_address ( array) ;
382- let key = BorrowKey :: from_array ( array) ;
383-
384- // SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
385- // and we are not calling into user code which might re-enter this function.
355+ fn release_mut ( & self , _py : Python , address : usize , key : BorrowKey ) {
356+ // SAFETY: Having `_py` implies holding the GIL and
357+ // we are not calling into user code which might re-enter this function.
386358 let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
387359
388360 let same_base_arrays = borrow_flags. get_mut ( & address) . unwrap ( ) ;
@@ -403,10 +375,16 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
403375/// i.e. that only shared references into the interior of the array can be created safely.
404376///
405377/// See the [module-level documentation](self) for more.
406- pub struct PyReadonlyArray < ' py , T , D > ( & ' py PyArray < T , D > )
378+ #[ repr( C ) ]
379+ pub struct PyReadonlyArray < ' py , T , D >
407380where
408381 T : Element ,
409- D : Dimension ;
382+ D : Dimension ,
383+ {
384+ array : & ' py PyArray < T , D > ,
385+ address : usize ,
386+ key : BorrowKey ,
387+ }
410388
411389/// Read-only borrow of a one-dimensional array.
412390pub type PyReadonlyArray1 < ' py , T > = PyReadonlyArray < ' py , T , Ix1 > ;
@@ -437,7 +415,7 @@ where
437415 type Target = PyArray < T , D > ;
438416
439417 fn deref ( & self ) -> & Self :: Target {
440- self . 0
418+ self . array
441419 }
442420}
443421
@@ -454,23 +432,30 @@ where
454432 D : Dimension ,
455433{
456434 pub ( crate ) fn try_new ( array : & ' py PyArray < T , D > ) -> Result < Self , BorrowError > {
457- BORROW_FLAGS . acquire ( array) ?;
435+ let address = base_address ( array) ;
436+ let key = BorrowKey :: from_array ( array) ;
458437
459- Ok ( Self ( array) )
438+ BORROW_FLAGS . acquire ( array. py ( ) , address, key) ?;
439+
440+ Ok ( Self {
441+ array,
442+ address,
443+ key,
444+ } )
460445 }
461446
462447 /// Provides an immutable array view of the interior of the NumPy array.
463448 #[ inline( always) ]
464449 pub fn as_array ( & self ) -> ArrayView < T , D > {
465450 // SAFETY: Global borrow flags ensure aliasing discipline.
466- unsafe { self . 0 . as_array ( ) }
451+ unsafe { self . array . as_array ( ) }
467452 }
468453
469454 /// Provide an immutable slice view of the interior of the NumPy array if it is contiguous.
470455 #[ inline( always) ]
471456 pub fn as_slice ( & self ) -> Result < & [ T ] , NotContiguousError > {
472457 // SAFETY: Global borrow flags ensure aliasing discipline.
473- unsafe { self . 0 . as_slice ( ) }
458+ unsafe { self . array . as_slice ( ) }
474459 }
475460
476461 /// Provide an immutable reference to an element of the NumPy array if the index is within bounds.
@@ -479,7 +464,7 @@ where
479464 where
480465 I : NpyIndex < Dim = D > ,
481466 {
482- unsafe { self . 0 . get ( index) }
467+ unsafe { self . array . get ( index) }
483468 }
484469}
485470
@@ -489,7 +474,15 @@ where
489474 D : Dimension ,
490475{
491476 fn clone ( & self ) -> Self {
492- Self :: try_new ( self . 0 ) . unwrap ( )
477+ BORROW_FLAGS
478+ . acquire ( self . array . py ( ) , self . address , self . key )
479+ . unwrap ( ) ;
480+
481+ Self {
482+ array : self . array ,
483+ address : self . address ,
484+ key : self . key ,
485+ }
493486 }
494487}
495488
@@ -499,7 +492,7 @@ where
499492 D : Dimension ,
500493{
501494 fn drop ( & mut self ) {
502- BORROW_FLAGS . release ( self . 0 ) ;
495+ BORROW_FLAGS . release ( self . array . py ( ) , self . address , self . key ) ;
503496 }
504497}
505498
@@ -525,10 +518,16 @@ where
525518/// i.e. that only a single exclusive reference into the interior of the array can be created safely.
526519///
527520/// See the [module-level documentation](self) for more.
528- pub struct PyReadwriteArray < ' py , T , D > ( & ' py PyArray < T , D > )
521+ #[ repr( C ) ]
522+ pub struct PyReadwriteArray < ' py , T , D >
529523where
530524 T : Element ,
531- D : Dimension ;
525+ D : Dimension ,
526+ {
527+ array : & ' py PyArray < T , D > ,
528+ address : usize ,
529+ key : BorrowKey ,
530+ }
532531
533532/// Read-write borrow of a one-dimensional array.
534533pub type PyReadwriteArray1 < ' py , T > = PyReadwriteArray < ' py , T , Ix1 > ;
@@ -581,23 +580,30 @@ where
581580 return Err ( BorrowError :: NotWriteable ) ;
582581 }
583582
584- BORROW_FLAGS . acquire_mut ( array) ?;
583+ let address = base_address ( array) ;
584+ let key = BorrowKey :: from_array ( array) ;
585585
586- Ok ( Self ( array) )
586+ BORROW_FLAGS . acquire_mut ( array. py ( ) , address, key) ?;
587+
588+ Ok ( Self {
589+ array,
590+ address,
591+ key,
592+ } )
587593 }
588594
589595 /// Provides a mutable array view of the interior of the NumPy array.
590596 #[ inline( always) ]
591597 pub fn as_array_mut ( & mut self ) -> ArrayViewMut < T , D > {
592598 // SAFETY: Global borrow flags ensure aliasing discipline.
593- unsafe { self . 0 . as_array_mut ( ) }
599+ unsafe { self . array . as_array_mut ( ) }
594600 }
595601
596602 /// Provide a mutable slice view of the interior of the NumPy array if it is contiguous.
597603 #[ inline( always) ]
598604 pub fn as_slice_mut ( & mut self ) -> Result < & mut [ T ] , NotContiguousError > {
599605 // SAFETY: Global borrow flags ensure aliasing discipline.
600- unsafe { self . 0 . as_slice_mut ( ) }
606+ unsafe { self . array . as_slice_mut ( ) }
601607 }
602608
603609 /// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
@@ -606,7 +612,7 @@ where
606612 where
607613 I : NpyIndex < Dim = D > ,
608614 {
609- unsafe { self . 0 . get_mut ( index) }
615+ unsafe { self . array . get_mut ( index) }
610616 }
611617}
612618
@@ -632,16 +638,16 @@ where
632638 /// });
633639 /// ```
634640 pub fn resize ( self , new_elems : usize ) -> PyResult < Self > {
635- BORROW_FLAGS . release_mut ( self . 0 ) ;
641+ let array = self . array ;
636642
637643 // SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
638644 unsafe {
639- self . 0 . resize ( new_elems) ?;
645+ array . resize ( new_elems) ?;
640646 }
641647
642- BORROW_FLAGS . acquire_mut ( self . 0 ) ? ;
648+ drop ( self ) ;
643649
644- Ok ( self )
650+ Ok ( Self :: try_new ( array ) . unwrap ( ) )
645651 }
646652}
647653
@@ -651,7 +657,7 @@ where
651657 D : Dimension ,
652658{
653659 fn drop ( & mut self ) {
654- BORROW_FLAGS . release_mut ( self . 0 ) ;
660+ BORROW_FLAGS . release_mut ( self . array . py ( ) , self . address , self . key ) ;
655661 }
656662}
657663
@@ -726,6 +732,10 @@ where
726732 )
727733}
728734
735+ fn gcd_strides ( strides : & [ isize ] ) -> isize {
736+ reduce ( strides. iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 )
737+ }
738+
729739// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
730740fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
731741 if lhs >= rhs {
@@ -1276,4 +1286,43 @@ mod tests {
12761286 }
12771287 } ) ;
12781288 }
1289+
1290+ #[ test]
1291+ #[ should_panic( expected = "AlreadyBorrowed" ) ]
1292+ fn cannot_clone_exclusive_borrow_via_deref ( ) {
1293+ Python :: with_gil ( |py| {
1294+ let array = PyArray :: < f64 , _ > :: zeros ( py, ( 3 , 2 , 1 ) , false ) ;
1295+
1296+ let exclusive = array. readwrite ( ) ;
1297+ let _shared = exclusive. clone ( ) ;
1298+ } ) ;
1299+ }
1300+
1301+ #[ test]
1302+ fn failed_resize_does_not_double_release ( ) {
1303+ Python :: with_gil ( |py| {
1304+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1305+
1306+ // The view will make the internal reference check of `PyArray_Resize` fail.
1307+ let locals = [ ( "array" , array) ] . into_py_dict ( py) ;
1308+ let _view = py
1309+ . eval ( "array[:]" , None , Some ( locals) )
1310+ . unwrap ( )
1311+ . downcast :: < PyArray1 < f64 > > ( )
1312+ . unwrap ( ) ;
1313+
1314+ let exclusive = array. readwrite ( ) ;
1315+ assert ! ( exclusive. resize( 100 ) . is_err( ) ) ;
1316+ } ) ;
1317+ }
1318+
1319+ #[ test]
1320+ fn ineffective_resize_does_not_conflict ( ) {
1321+ Python :: with_gil ( |py| {
1322+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1323+
1324+ let exclusive = array. readwrite ( ) ;
1325+ assert ! ( exclusive. resize( 10 ) . is_ok( ) ) ;
1326+ } ) ;
1327+ }
12791328}
0 commit comments