@@ -180,7 +180,7 @@ use crate::dtype::Element;
180180use crate :: error:: { BorrowError , NotContiguousError } ;
181181use crate :: npyffi:: { self , PyArrayObject , NPY_ARRAY_WRITEABLE } ;
182182
183- #[ derive( PartialEq , Eq , Hash ) ]
183+ #[ derive( Clone , Copy , PartialEq , Eq , Hash ) ]
184184struct BorrowKey {
185185 /// exclusive range of lowest and highest address covered by array
186186 range : ( usize , usize ) ,
@@ -375,10 +375,16 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
375375/// i.e. that only shared references into the interior of the array can be created safely.
376376///
377377/// See the [module-level documentation](self) for more.
378- pub struct PyReadonlyArray < ' py , T , D > ( & ' py PyArray < T , D > )
378+ #[ repr( C ) ]
379+ pub struct PyReadonlyArray < ' py , T , D >
379380where
380381 T : Element ,
381- D : Dimension ;
382+ D : Dimension ,
383+ {
384+ array : & ' py PyArray < T , D > ,
385+ address : usize ,
386+ key : BorrowKey ,
387+ }
382388
383389/// Read-only borrow of a one-dimensional array.
384390pub type PyReadonlyArray1 < ' py , T > = PyReadonlyArray < ' py , T , Ix1 > ;
@@ -409,7 +415,7 @@ where
409415 type Target = PyArray < T , D > ;
410416
411417 fn deref ( & self ) -> & Self :: Target {
412- self . 0
418+ self . array
413419 }
414420}
415421
@@ -426,27 +432,30 @@ where
426432 D : Dimension ,
427433{
428434 pub ( crate ) fn try_new ( array : & ' py PyArray < T , D > ) -> Result < Self , BorrowError > {
429- let py = array. py ( ) ;
430435 let address = base_address ( array) ;
431436 let key = BorrowKey :: from_array ( array) ;
432437
433- BORROW_FLAGS . acquire ( py , address, key) ?;
438+ BORROW_FLAGS . acquire ( array . py ( ) , address, key) ?;
434439
435- Ok ( Self ( array) )
440+ Ok ( Self {
441+ array,
442+ address,
443+ key,
444+ } )
436445 }
437446
438447 /// Provides an immutable array view of the interior of the NumPy array.
439448 #[ inline( always) ]
440449 pub fn as_array ( & self ) -> ArrayView < T , D > {
441450 // SAFETY: Global borrow flags ensure aliasing discipline.
442- unsafe { self . 0 . as_array ( ) }
451+ unsafe { self . array . as_array ( ) }
443452 }
444453
445454 /// Provide an immutable slice view of the interior of the NumPy array if it is contiguous.
446455 #[ inline( always) ]
447456 pub fn as_slice ( & self ) -> Result < & [ T ] , NotContiguousError > {
448457 // SAFETY: Global borrow flags ensure aliasing discipline.
449- unsafe { self . 0 . as_slice ( ) }
458+ unsafe { self . array . as_slice ( ) }
450459 }
451460
452461 /// Provide an immutable reference to an element of the NumPy array if the index is within bounds.
@@ -455,7 +464,7 @@ where
455464 where
456465 I : NpyIndex < Dim = D > ,
457466 {
458- unsafe { self . 0 . get ( index) }
467+ unsafe { self . array . get ( index) }
459468 }
460469}
461470
@@ -465,7 +474,15 @@ where
465474 D : Dimension ,
466475{
467476 fn clone ( & self ) -> Self {
468- Self :: try_new ( self . 0 ) . unwrap ( )
477+ BORROW_FLAGS
478+ . acquire ( self . array . py ( ) , self . address , self . key )
479+ . unwrap ( ) ;
480+
481+ Self {
482+ array : self . array ,
483+ address : self . address ,
484+ key : self . key ,
485+ }
469486 }
470487}
471488
@@ -475,11 +492,7 @@ where
475492 D : Dimension ,
476493{
477494 fn drop ( & mut self ) {
478- let py = self . 0 . py ( ) ;
479- let address = base_address ( self . 0 ) ;
480- let key = BorrowKey :: from_array ( self . 0 ) ;
481-
482- BORROW_FLAGS . release ( py, address, key) ;
495+ BORROW_FLAGS . release ( self . array . py ( ) , self . address , self . key ) ;
483496 }
484497}
485498
@@ -505,10 +518,16 @@ where
505518/// i.e. that only a single exclusive reference into the interior of the array can be created safely.
506519///
507520/// See the [module-level documentation](self) for more.
508- pub struct PyReadwriteArray < ' py , T , D > ( & ' py PyArray < T , D > )
521+ #[ repr( C ) ]
522+ pub struct PyReadwriteArray < ' py , T , D >
509523where
510524 T : Element ,
511- D : Dimension ;
525+ D : Dimension ,
526+ {
527+ array : & ' py PyArray < T , D > ,
528+ address : usize ,
529+ key : BorrowKey ,
530+ }
512531
513532/// Read-write borrow of a one-dimensional array.
514533pub type PyReadwriteArray1 < ' py , T > = PyReadwriteArray < ' py , T , Ix1 > ;
@@ -561,27 +580,30 @@ where
561580 return Err ( BorrowError :: NotWriteable ) ;
562581 }
563582
564- let py = array. py ( ) ;
565583 let address = base_address ( array) ;
566584 let key = BorrowKey :: from_array ( array) ;
567585
568- BORROW_FLAGS . acquire_mut ( py , address, key) ?;
586+ BORROW_FLAGS . acquire_mut ( array . py ( ) , address, key) ?;
569587
570- Ok ( Self ( array) )
588+ Ok ( Self {
589+ array,
590+ address,
591+ key,
592+ } )
571593 }
572594
573595 /// Provides a mutable array view of the interior of the NumPy array.
574596 #[ inline( always) ]
575597 pub fn as_array_mut ( & mut self ) -> ArrayViewMut < T , D > {
576598 // SAFETY: Global borrow flags ensure aliasing discipline.
577- unsafe { self . 0 . as_array_mut ( ) }
599+ unsafe { self . array . as_array_mut ( ) }
578600 }
579601
580602 /// Provide a mutable slice view of the interior of the NumPy array if it is contiguous.
581603 #[ inline( always) ]
582604 pub fn as_slice_mut ( & mut self ) -> Result < & mut [ T ] , NotContiguousError > {
583605 // SAFETY: Global borrow flags ensure aliasing discipline.
584- unsafe { self . 0 . as_slice_mut ( ) }
606+ unsafe { self . array . as_slice_mut ( ) }
585607 }
586608
587609 /// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
@@ -590,7 +612,7 @@ where
590612 where
591613 I : NpyIndex < Dim = D > ,
592614 {
593- unsafe { self . 0 . get_mut ( index) }
615+ unsafe { self . array . get_mut ( index) }
594616 }
595617}
596618
@@ -616,23 +638,16 @@ where
616638 /// });
617639 /// ```
618640 pub fn resize ( self , new_elems : usize ) -> PyResult < Self > {
619- let py = self . 0 . py ( ) ;
620- let address = base_address ( self . 0 ) ;
621- let key = BorrowKey :: from_array ( self . 0 ) ;
622-
623- BORROW_FLAGS . release_mut ( py, address, key) ;
641+ let array = self . array ;
624642
625643 // SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
626644 unsafe {
627- self . 0 . resize ( new_elems) ?;
645+ array . resize ( new_elems) ?;
628646 }
629647
630- let address = base_address ( self . 0 ) ;
631- let key = BorrowKey :: from_array ( self . 0 ) ;
632-
633- BORROW_FLAGS . acquire_mut ( py, address, key) ?;
648+ drop ( self ) ;
634649
635- Ok ( self )
650+ Ok ( Self :: try_new ( array ) . unwrap ( ) )
636651 }
637652}
638653
@@ -642,11 +657,7 @@ where
642657 D : Dimension ,
643658{
644659 fn drop ( & mut self ) {
645- let py = self . 0 . py ( ) ;
646- let address = base_address ( self . 0 ) ;
647- let key = BorrowKey :: from_array ( self . 0 ) ;
648-
649- BORROW_FLAGS . release_mut ( py, address, key) ;
660+ BORROW_FLAGS . release_mut ( self . array . py ( ) , self . address , self . key ) ;
650661 }
651662}
652663
@@ -1275,4 +1286,43 @@ mod tests {
12751286 }
12761287 } ) ;
12771288 }
1289+
1290+ #[ test]
1291+ #[ should_panic( expected = "AlreadyBorrowed" ) ]
1292+ fn cannot_clone_exclusive_borrow_via_deref ( ) {
1293+ Python :: with_gil ( |py| {
1294+ let array = PyArray :: < f64 , _ > :: zeros ( py, ( 3 , 2 , 1 ) , false ) ;
1295+
1296+ let exclusive = array. readwrite ( ) ;
1297+ let _shared = exclusive. clone ( ) ;
1298+ } ) ;
1299+ }
1300+
1301+ #[ test]
1302+ fn failed_resize_does_not_double_release ( ) {
1303+ Python :: with_gil ( |py| {
1304+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1305+
1306+ // The view will make the internal reference check of `PyArray_Resize` fail.
1307+ let locals = [ ( "array" , array) ] . into_py_dict ( py) ;
1308+ let _view = py
1309+ . eval ( "array[:]" , None , Some ( locals) )
1310+ . unwrap ( )
1311+ . downcast :: < PyArray1 < f64 > > ( )
1312+ . unwrap ( ) ;
1313+
1314+ let exclusive = array. readwrite ( ) ;
1315+ assert ! ( exclusive. resize( 100 ) . is_err( ) ) ;
1316+ } ) ;
1317+ }
1318+
1319+ #[ test]
1320+ fn ineffective_resize_does_not_conflict ( ) {
1321+ Python :: with_gil ( |py| {
1322+ let array = PyArray :: < f64 , _ > :: zeros ( py, 10 , false ) ;
1323+
1324+ let exclusive = array. readwrite ( ) ;
1325+ assert ! ( exclusive. resize( 10 ) . is_ok( ) ) ;
1326+ } ) ;
1327+ }
12781328}
0 commit comments