Skip to content

Commit 219c885

Browse files
committed
Overhaul conversion methods
1 parent 323a3e7 commit 219c885

File tree

3 files changed

+68
-20
lines changed

3 files changed

+68
-20
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Changelog
2+
<<<<<<< HEAD
23
- v0.10.0
34
- Remove `ErrorKind` and introduce some concrete error types
5+
=======
6+
- Unreleased
7+
- `PyArray::as_slice_mut` and `PyArray::as_array_mut` is now unsafe.
8+
- Introduce `PyArray::as_cell_slice`, `PyArray::as_cow_array` and `PyArray::to_vec`
9+
>>>>>>> 660f04f... Overhaul conversion methods
410
511
- v0.9.0
612
- Update PyO3 to 0.10.0

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
140140
// wrapper of `mult`
141141
#[pyfn(m, "mult")]
142142
fn mult_py(_py: Python, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
143-
let x = x.as_array_mut();
143+
let x = unsafe { x.as_array_mut() };
144144
mult(a, x);
145145
Ok(())
146146
}

src/array.rs

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use ndarray::*;
44
use num_traits::AsPrimitive;
55
use pyo3::{ffi, prelude::*, type_object, types::PyAny};
66
use pyo3::{AsPyPointer, PyDowncastError, PyNativeType, PyResult};
7+
use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
78
use std::{iter::ExactSizeIterator, marker::PhantomData};
8-
use std::{mem, os::raw::c_int, ptr, slice};
99

1010
use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
1111
use crate::error::{FromVecError, NotContiguousError, ShapeError};
@@ -429,7 +429,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
429429
}
430430
}
431431

432-
/// Get the immutable view of the internal data of `PyArray`, as slice.
432+
/// Returns the immutable view of the internal data of `PyArray` as slice.
433433
///
434434
/// Returns `ErrorKind::NotContiguous` if the internal array is not contiguous.
435435
/// # Example
@@ -456,12 +456,27 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
456456
}
457457
}
458458

459-
/// Get the mmutable view of the internal data of `PyArray`, as slice.
460-
pub fn as_slice_mut(&self) -> Result<&mut [T], NotContiguousError> {
459+
/// Returns the view of the internal data of `PyArray` as `&[Cell<T>]`.
460+
pub fn as_cell_slice(&self) -> Result<&[Cell<T>], NotContiguousError> {
461461
if !self.is_contiguous() {
462462
Err(NotContiguousError)
463463
} else {
464-
Ok(unsafe { slice::from_raw_parts_mut(self.data(), self.len()) })
464+
Ok(unsafe { slice::from_raw_parts(self.data() as _, self.len()) })
465+
}
466+
}
467+
468+
/// Returns the view of the internal data of `PyArray` as mutable slice.
469+
///
470+
/// # Soundness
471+
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
472+
/// it causes an undefined undefined behavior.
473+
///
474+
/// In such case, please consider the use of [`as_cell_slice`](#method.as_cell_slice),
475+
pub unsafe fn as_slice_mut(&self) -> Result<&mut [T], NotContiguousError> {
476+
if !self.is_contiguous() {
477+
Err(NotContiguousError)
478+
} else {
479+
Ok(slice::from_raw_parts_mut(self.data(), self.len()))
465480
}
466481
}
467482

@@ -522,10 +537,22 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
522537
unsafe { ArrayView::from_shape_ptr(self.ndarray_shape(), self.data()) }
523538
}
524539

525-
/// Almost same as [`as_array`](#method.as_array), but returns `ArrayViewMut`.
526-
pub fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
540+
/// Returns the internal array as `CowArray`. See also [`as_array`](#method.as_array).
541+
pub fn as_cow_array(&self) -> CowArray<'_, T, D> {
542+
CowArray::from(self.as_array())
543+
}
544+
545+
/// Returns the internal array as `ArrayViewMut`. See also [`as_array`](#method.as_array).
546+
///
547+
/// # Soundness
548+
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
549+
/// it causes an undefined undefined behavior.
550+
///
551+
/// In such case, please consider the use of
552+
/// [`as_cell_array`](#method.as_cell_array) or [`as_cow_array`](#method.as_cow_array).
553+
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
527554
self.type_check_assert();
528-
unsafe { ArrayViewMut::from_shape_ptr(self.ndarray_shape(), self.data()) }
555+
ArrayViewMut::from_shape_ptr(self.ndarray_shape(), self.data())
529556
}
530557

531558
/// Get an immutable reference of a specified element, with checking the passed index is valid.
@@ -568,16 +595,6 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
568595
unsafe { Some(&*self.data().offset(offset)) }
569596
}
570597

571-
/// Same as [get](#method.get), but returns `Option<&mut T>`.
572-
#[inline(always)]
573-
pub fn get_mut<Idx>(&self, index: Idx) -> Option<&mut T>
574-
where
575-
Idx: NpyIndex<Dim = D>,
576-
{
577-
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
578-
unsafe { Some(&mut *(self.data().offset(offset) as *mut T)) }
579-
}
580-
581598
/// Get an immutable reference of a specified element, without checking the
582599
/// passed index is valid.
583600
///
@@ -637,7 +654,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
637654
}
638655
}
639656

640-
impl<T: TypeNum + Clone, D: Dimension> PyArray<T, D> {
657+
impl<T: Clone + TypeNum, D: Dimension> PyArray<T, D> {
641658
/// Get a copy of `PyArray` as
642659
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
643660
///
@@ -655,6 +672,31 @@ impl<T: TypeNum + Clone, D: Dimension> PyArray<T, D> {
655672
pub fn to_owned_array(&self) -> Array<T, D> {
656673
self.as_array().to_owned()
657674
}
675+
676+
/// Returns the copy of the internal data of `PyArray` to `Vec`.
677+
///
678+
/// Returns `ErrorKind::NotContiguous` if the internal array is not contiguous.
679+
/// See also [`as_slice`](#method.as_slice)
680+
///
681+
/// # Example
682+
/// ```
683+
/// # fn main() {
684+
/// use numpy::PyArray2;
685+
/// use pyo3::types::IntoPyDict;
686+
/// let gil = pyo3::Python::acquire_gil();
687+
/// let py = gil.python();
688+
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
689+
/// let array: &PyArray2<i64> = py
690+
/// .eval("np.array([[0, 1], [2, 3]], dtype='int64')", Some(locals), None)
691+
/// .unwrap()
692+
/// .downcast()
693+
/// .unwrap();
694+
/// assert_eq!(array.to_vec().unwrap(), vec![0, 1, 2, 3]);
695+
/// # }
696+
/// ```
697+
pub fn to_vec(&self) -> Result<Vec<T>, NotContiguousError> {
698+
self.as_slice().map(ToOwned::to_owned)
699+
}
658700
}
659701

660702
impl<T: TypeNum> PyArray<T, Ix1> {

0 commit comments

Comments
 (0)