@@ -249,7 +249,7 @@ impl<T, D> PyArray<T, D> {
249249 /// ```
250250 /// use numpy::PyArray3;
251251 /// pyo3::Python::with_gil(|py| {
252- /// let arr = PyArray3::<f64>::new (py, [4, 5, 6], false);
252+ /// let arr = PyArray3::<f64>::zeros (py, [4, 5, 6], false);
253253 /// assert_eq!(arr.ndim(), 3);
254254 /// });
255255 /// ```
@@ -266,7 +266,7 @@ impl<T, D> PyArray<T, D> {
266266 /// ```
267267 /// use numpy::PyArray3;
268268 /// pyo3::Python::with_gil(|py| {
269- /// let arr = PyArray3::<f64>::new (py, [4, 5, 6], false);
269+ /// let arr = PyArray3::<f64>::zeros (py, [4, 5, 6], false);
270270 /// assert_eq!(arr.strides(), &[240, 48, 8]);
271271 /// });
272272 /// ```
@@ -287,7 +287,7 @@ impl<T, D> PyArray<T, D> {
287287 /// ```
288288 /// use numpy::PyArray3;
289289 /// pyo3::Python::with_gil(|py| {
290- /// let arr = PyArray3::<f64>::new (py, [4, 5, 6], false);
290+ /// let arr = PyArray3::<f64>::zeros (py, [4, 5, 6], false);
291291 /// assert_eq!(arr.shape(), &[4, 5, 6]);
292292 /// });
293293 /// ```
@@ -371,20 +371,46 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
371371 ///
372372 /// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
373373 ///
374+ /// # Safety
375+ ///
376+ /// The returned array will always be safe to be dropped as the elements must either
377+ /// be trivially copyable or have `DATA_TYPE == DataType::Object`, i.e. be pointers
378+ /// into Python's heap, which NumPy will automatically zero-initialize.
379+ ///
380+ /// However, the elements themselves will not be valid and should only be accessed
381+ /// via raw pointers obtained via [uget_raw](#method.uget_raw).
382+ ///
383+ /// All methods which produce references to the elements invoke undefined behaviour.
384+ /// In particular, zero-initialized pointers are _not_ valid instances of `PyObject`.
385+ ///
374386 /// # Example
375387 /// ```
376388 /// use numpy::PyArray3;
389+ ///
377390 /// pyo3::Python::with_gil(|py| {
378- /// let arr = PyArray3::<i32>::new(py, [4, 5, 6], false);
391+ /// let arr = unsafe {
392+ /// let arr = PyArray3::<i32>::new(py, [4, 5, 6], false);
393+ ///
394+ /// for i in 0..4 {
395+ /// for j in 0..5 {
396+ /// for k in 0..6 {
397+ /// arr.uget_raw([i, j, k]).write((i * j * k) as i32);
398+ /// }
399+ /// }
400+ /// }
401+ ///
402+ /// arr
403+ /// };
404+ ///
379405 /// assert_eq!(arr.shape(), &[4, 5, 6]);
380406 /// });
381407 /// ```
382- pub fn new < ID > ( py : Python , dims : ID , is_fortran : bool ) -> & Self
408+ pub unsafe fn new < ID > ( py : Python , dims : ID , is_fortran : bool ) -> & Self
383409 where
384410 ID : IntoDimension < Dim = D > ,
385411 {
386412 let flags = if is_fortran { 1 } else { 0 } ;
387- unsafe { PyArray :: new_ ( py, dims, ptr:: null_mut ( ) , flags) }
413+ PyArray :: new_ ( py, dims, ptr:: null_mut ( ) , flags)
388414 }
389415
390416 pub ( crate ) unsafe fn new_ < ID > (
@@ -448,7 +474,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
448474 /// a fortran order array is created, otherwise a C-order array is created.
449475 ///
450476 /// For elements with `DATA_TYPE == DataType::Object`, this will fill the array
451- /// valid pointers to objects of type `<class 'int'>` with value zero .
477+ /// with valid pointers to zero-valued Python integer objects .
452478 ///
453479 /// See also [PyArray_Zeros](https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Zeros)
454480 ///
@@ -596,6 +622,16 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
596622 & mut * ( self . data ( ) . offset ( offset) as * mut _ )
597623 }
598624
625+ /// Same as [uget](#method.uget), but returns `*mut T`.
626+ #[ inline( always) ]
627+ pub unsafe fn uget_raw < Idx > ( & self , index : Idx ) -> * mut T
628+ where
629+ Idx : NpyIndex < Dim = D > ,
630+ {
631+ let offset = index. get_unchecked :: < T > ( self . strides ( ) ) ;
632+ self . data ( ) . offset ( offset) as * mut _
633+ }
634+
599635 /// Get dynamic dimensioned array from fixed dimension array.
600636 pub fn to_dyn ( & self ) -> & PyArray < T , IxDyn > {
601637 let python = self . py ( ) ;
@@ -733,20 +769,18 @@ impl<T: Element> PyArray<T, Ix1> {
733769 /// });
734770 /// ```
735771 pub fn from_slice < ' py > ( py : Python < ' py > , slice : & [ T ] ) -> & ' py Self {
736- let array = PyArray :: new ( py , [ slice . len ( ) ] , false ) ;
737- if T :: DATA_TYPE != DataType :: Object {
738- unsafe {
772+ unsafe {
773+ let array = PyArray :: new ( py , [ slice . len ( ) ] , false ) ;
774+ if T :: DATA_TYPE != DataType :: Object {
739775 array. copy_ptr ( slice. as_ptr ( ) , slice. len ( ) ) ;
740- }
741- } else {
742- unsafe {
776+ } else {
743777 let data_ptr = array. data ( ) ;
744778 for ( i, item) in slice. iter ( ) . enumerate ( ) {
745779 data_ptr. add ( i) . write ( item. clone ( ) ) ;
746780 }
747781 }
782+ array
748783 }
749- array
750784 }
751785
752786 /// Construct one-dimension PyArray
@@ -781,13 +815,13 @@ impl<T: Element> PyArray<T, Ix1> {
781815 pub fn from_exact_iter ( py : Python < ' _ > , iter : impl ExactSizeIterator < Item = T > ) -> & Self {
782816 // NumPy will always zero-initialize object pointers,
783817 // so the array can be dropped safely if the iterator panics.
784- let array = Self :: new ( py, [ iter. len ( ) ] , false ) ;
785818 unsafe {
819+ let array = Self :: new ( py, [ iter. len ( ) ] , false ) ;
786820 for ( i, item) in iter. enumerate ( ) {
787- * array. uget_mut ( [ i] ) = item;
821+ array. uget_raw ( [ i] ) . write ( item) ;
788822 }
823+ array
789824 }
790- array
791825 }
792826
793827 /// Construct one-dimension PyArray from a type which implements
@@ -809,11 +843,11 @@ impl<T: Element> PyArray<T, Ix1> {
809843 let iter = iter. into_iter ( ) ;
810844 let ( min_len, max_len) = iter. size_hint ( ) ;
811845 let mut capacity = max_len. unwrap_or_else ( || min_len. max ( 512 / mem:: size_of :: < T > ( ) ) ) ;
812- // NumPy will always zero-initialize object pointers,
813- // so the array can be dropped safely if the iterator panics.
814- let array = Self :: new ( py, [ capacity] , false ) ;
815- let mut length = 0 ;
816846 unsafe {
847+ // NumPy will always zero-initialize object pointers,
848+ // so the array can be dropped safely if the iterator panics.
849+ let array = Self :: new ( py, [ capacity] , false ) ;
850+ let mut length = 0 ;
817851 for ( i, item) in iter. enumerate ( ) {
818852 length += 1 ;
819853 if length > capacity {
@@ -822,13 +856,13 @@ impl<T: Element> PyArray<T, Ix1> {
822856 . resize ( capacity)
823857 . expect ( "PyArray::from_iter: Failed to allocate memory" ) ;
824858 }
825- * array. uget_mut ( [ i] ) = item;
859+ array. uget_raw ( [ i] ) . write ( item) ;
826860 }
861+ if capacity > length {
862+ array. resize ( length) . unwrap ( )
863+ }
864+ array
827865 }
828- if capacity > length {
829- array. resize ( length) . unwrap ( )
830- }
831- array
832866 }
833867
834868 /// Extends or trancates the length of 1 dimension PyArray.
@@ -902,15 +936,15 @@ impl<T: Element> PyArray<T, Ix2> {
902936 return Err ( FromVecError :: new ( v. len ( ) , last_len) ) ;
903937 }
904938 let dims = [ v. len ( ) , last_len] ;
905- let array = Self :: new ( py, dims, false ) ;
906939 unsafe {
940+ let array = Self :: new ( py, dims, false ) ;
907941 for ( y, vy) in v. iter ( ) . enumerate ( ) {
908942 for ( x, vyx) in vy. iter ( ) . enumerate ( ) {
909- * array. uget_mut ( [ y, x] ) = vyx. clone ( ) ;
943+ array. uget_raw ( [ y, x] ) . write ( vyx. clone ( ) ) ;
910944 }
911945 }
946+ Ok ( array)
912947 }
913- Ok ( array)
914948 }
915949}
916950
@@ -944,17 +978,17 @@ impl<T: Element> PyArray<T, Ix3> {
944978 return Err ( FromVecError :: new ( v. len ( ) , len3) ) ;
945979 }
946980 let dims = [ v. len ( ) , len2, len3] ;
947- let array = Self :: new ( py, dims, false ) ;
948981 unsafe {
982+ let array = Self :: new ( py, dims, false ) ;
949983 for ( z, vz) in v. iter ( ) . enumerate ( ) {
950984 for ( y, vzy) in vz. iter ( ) . enumerate ( ) {
951985 for ( x, vzyx) in vzy. iter ( ) . enumerate ( ) {
952- * array. uget_mut ( [ z, y, x] ) = vzyx. clone ( ) ;
986+ array. uget_raw ( [ z, y, x] ) . write ( vzyx. clone ( ) ) ;
953987 }
954988 }
955989 }
990+ Ok ( array)
956991 }
957- Ok ( array)
958992 }
959993}
960994
@@ -965,7 +999,7 @@ impl<T: Element, D> PyArray<T, D> {
965999 /// use numpy::PyArray;
9661000 /// pyo3::Python::with_gil(|py| {
9671001 /// let pyarray_f = PyArray::arange(py, 2.0, 5.0, 1.0);
968- /// let pyarray_i = PyArray::<i64, _>::new(py, [3], false);
1002+ /// let pyarray_i = unsafe { PyArray::<i64, _>::new(py, [3], false) } ;
9691003 /// assert!(pyarray_f.copy_to(pyarray_i).is_ok());
9701004 /// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
9711005 /// });
0 commit comments