@@ -4,7 +4,7 @@ use std::ptr::null_mut;
44
55use ndarray:: { Dimension , IxDyn } ;
66use pyo3:: types:: PyAnyMethods ;
7- use pyo3:: { AsPyPointer , Bound , FromPyObject , PyNativeType , PyResult } ;
7+ use pyo3:: { Borrowed , Bound , FromPyObject , PyNativeType , PyResult } ;
88
99use crate :: array:: PyArray ;
1010use crate :: dtype:: Element ;
2020{
2121}
2222
23+ impl < ' py , T , D > ArrayOrScalar < ' py , T > for Bound < ' py , PyArray < T , D > >
24+ where
25+ T : Element ,
26+ D : Dimension ,
27+ {
28+ }
29+
2330impl < ' py , T > ArrayOrScalar < ' py , T > for T where T : Element + FromPyObject < ' py > { }
2431
32+ /// Deprecated form of [`inner_bound`]
33+ #[ deprecated(
34+ since = "0.21.0" ,
35+ note = "will be replaced by `inner_bound` in the future"
36+ ) ]
37+ pub fn inner < ' py , T , DIN1 , DIN2 , OUT > (
38+ array1 : & ' py PyArray < T , DIN1 > ,
39+ array2 : & ' py PyArray < T , DIN2 > ,
40+ ) -> PyResult < OUT >
41+ where
42+ T : Element ,
43+ DIN1 : Dimension ,
44+ DIN2 : Dimension ,
45+ OUT : ArrayOrScalar < ' py , T > ,
46+ {
47+ inner_bound ( & array1. as_borrowed ( ) , & array2. as_borrowed ( ) )
48+ }
49+
2550/// Return the inner product of two arrays.
2651///
2752/// [NumPy's documentation][inner] has the details.
@@ -31,33 +56,33 @@ impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
3156/// Note that this function can either return a scalar...
3257///
3358/// ```
34- /// use pyo3::Python;
35- /// use numpy::{inner , pyarray, PyArray0};
59+ /// use pyo3::{ Python, PyNativeType} ;
60+ /// use numpy::{inner_bound , pyarray, PyArray0};
3661///
3762/// Python::with_gil(|py| {
38- /// let vector = pyarray![py, 1.0, 2.0, 3.0];
39- /// let result: f64 = inner( vector, vector).unwrap();
63+ /// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed() ;
64+ /// let result: f64 = inner_bound(& vector, & vector).unwrap();
4065/// assert_eq!(result, 14.0);
4166/// });
4267/// ```
4368///
4469/// ...or an array depending on its arguments.
4570///
4671/// ```
47- /// use pyo3::Python;
48- /// use numpy::{inner , pyarray, PyArray0};
72+ /// use pyo3::{ Python, Bound, PyNativeType} ;
73+ /// use numpy::{inner_bound , pyarray, PyArray0, PyArray0Methods };
4974///
5075/// Python::with_gil(|py| {
51- /// let vector = pyarray![py, 1, 2, 3];
52- /// let result: & PyArray0<_> = inner( vector, vector).unwrap();
76+ /// let vector = pyarray![py, 1, 2, 3].as_borrowed() ;
77+ /// let result: Bound<'_, PyArray0<_>> = inner_bound(& vector, & vector).unwrap();
5378/// assert_eq!(result.item(), 14);
5479/// });
5580/// ```
5681///
5782/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
58- pub fn inner < ' py , T , DIN1 , DIN2 , OUT > (
59- array1 : & ' py PyArray < T , DIN1 > ,
60- array2 : & ' py PyArray < T , DIN2 > ,
83+ pub fn inner_bound < ' py , T , DIN1 , DIN2 , OUT > (
84+ array1 : & Bound < ' py , PyArray < T , DIN1 > > ,
85+ array2 : & Bound < ' py , PyArray < T , DIN2 > > ,
6186) -> PyResult < OUT >
6287where
6388 T : Element ,
7398 obj. extract ( )
7499}
75100
101+ /// Deprecated form of [`dot_bound`]
102+ #[ deprecated(
103+ since = "0.21.0" ,
104+ note = "will be replaced by `dot_bound` in the future"
105+ ) ]
106+ pub fn dot < ' py , T , DIN1 , DIN2 , OUT > (
107+ array1 : & ' py PyArray < T , DIN1 > ,
108+ array2 : & ' py PyArray < T , DIN2 > ,
109+ ) -> PyResult < OUT >
110+ where
111+ T : Element ,
112+ DIN1 : Dimension ,
113+ DIN2 : Dimension ,
114+ OUT : ArrayOrScalar < ' py , T > ,
115+ {
116+ dot_bound ( & array1. as_borrowed ( ) , & array2. as_borrowed ( ) )
117+ }
118+
76119/// Return the dot product of two arrays.
77120///
78121/// [NumPy's documentation][dot] has the details.
@@ -82,15 +125,15 @@ where
82125/// Note that this function can either return an array...
83126///
84127/// ```
85- /// use pyo3::Python;
128+ /// use pyo3::{ Python, Bound, PyNativeType} ;
86129/// use ndarray::array;
87- /// use numpy::{dot , pyarray, PyArray2};
130+ /// use numpy::{dot_bound , pyarray, PyArray2, PyArrayMethods };
88131///
89132/// Python::with_gil(|py| {
90- /// let matrix = pyarray![py, [1, 0], [0, 1]];
91- /// let another_matrix = pyarray![py, [4, 1], [2, 2]];
133+ /// let matrix = pyarray![py, [1, 0], [0, 1]].as_borrowed() ;
134+ /// let another_matrix = pyarray![py, [4, 1], [2, 2]].as_borrowed() ;
92135///
93- /// let result: & PyArray2<_> = numpy::dot( matrix, another_matrix).unwrap();
136+ /// let result: Bound<'_, PyArray2<_>> = dot_bound(& matrix, & another_matrix).unwrap();
94137///
95138/// assert_eq!(
96139/// result.readonly().as_array(),
@@ -102,20 +145,20 @@ where
102145/// ...or a scalar depending on its arguments.
103146///
104147/// ```
105- /// use pyo3::Python;
106- /// use numpy::{dot , pyarray, PyArray0};
148+ /// use pyo3::{ Python, PyNativeType} ;
149+ /// use numpy::{dot_bound , pyarray, PyArray0};
107150///
108151/// Python::with_gil(|py| {
109- /// let vector = pyarray![py, 1.0, 2.0, 3.0];
110- /// let result: f64 = dot( vector, vector).unwrap();
152+ /// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed() ;
153+ /// let result: f64 = dot_bound(& vector, & vector).unwrap();
111154/// assert_eq!(result, 14.0);
112155/// });
113156/// ```
114157///
115158/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
116- pub fn dot < ' py , T , DIN1 , DIN2 , OUT > (
117- array1 : & ' py PyArray < T , DIN1 > ,
118- array2 : & ' py PyArray < T , DIN2 > ,
159+ pub fn dot_bound < ' py , T , DIN1 , DIN2 , OUT > (
160+ array1 : & Bound < ' py , PyArray < T , DIN1 > > ,
161+ array2 : & Bound < ' py , PyArray < T , DIN2 > > ,
119162) -> PyResult < OUT >
120163where
121164 T : Element ,
@@ -131,10 +174,30 @@ where
131174 obj. extract ( )
132175}
133176
177+ /// Deprecated form of [`einsum_bound`]
178+ #[ deprecated(
179+ since = "0.21.0" ,
180+ note = "will be replaced by `einsum_bound` in the future"
181+ ) ]
182+ pub fn einsum < ' py , T , OUT > ( subscripts : & str , arrays : & [ & ' py PyArray < T , IxDyn > ] ) -> PyResult < OUT >
183+ where
184+ T : Element ,
185+ OUT : ArrayOrScalar < ' py , T > ,
186+ {
187+ // Safety: &PyArray<T, IxDyn> has the same size and layout in memory as
188+ // Borrowed<'_, '_, PyArray<T, IxDyn>>
189+ einsum_bound ( subscripts, unsafe {
190+ std:: slice:: from_raw_parts ( arrays. as_ptr ( ) . cast ( ) , arrays. len ( ) )
191+ } )
192+ }
193+
134194/// Return the Einstein summation convention of given tensors.
135195///
136196/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
137- pub fn einsum < ' py , T , OUT > ( subscripts : & str , arrays : & [ & ' py PyArray < T , IxDyn > ] ) -> PyResult < OUT >
197+ pub fn einsum_bound < ' py , T , OUT > (
198+ subscripts : & str ,
199+ arrays : & [ Borrowed < ' _ , ' py , PyArray < T , IxDyn > > ] ,
200+ ) -> PyResult < OUT >
138201where
139202 T : Element ,
140203 OUT : ArrayOrScalar < ' py , T > ,
@@ -161,6 +224,20 @@ where
161224 obj. extract ( )
162225}
163226
227+ /// Deprecated form of [`einsum_bound!`]
228+ #[ deprecated(
229+ since = "0.21.0" ,
230+ note = "will be replaced by `einsum_bound!` in the future"
231+ ) ]
232+ #[ macro_export]
233+ macro_rules! einsum {
234+ ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
235+ use pyo3:: PyNativeType ;
236+ let arrays = [ $( $array. to_dyn( ) . as_borrowed( ) , ) +] ;
237+ $crate:: einsum_bound( concat!( $subscripts, "\0 " ) , & arrays)
238+ } } ;
239+ }
240+
164241/// Return the Einstein summation convention of given tensors.
165242///
166243/// For more about the Einstein summation convention, please refer to
@@ -169,15 +246,15 @@ where
169246/// # Example
170247///
171248/// ```
172- /// use pyo3::Python;
249+ /// use pyo3::{ Python, Bound, PyNativeType} ;
173250/// use ndarray::array;
174- /// use numpy::{einsum , pyarray, PyArray, PyArray2, PyArrayMethods};
251+ /// use numpy::{einsum_bound , pyarray, PyArray, PyArray2, PyArrayMethods};
175252///
176253/// Python::with_gil(|py| {
177- /// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap().into_gil_ref() ;
178- /// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
254+ /// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
255+ /// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]].as_borrowed() ;
179256///
180- /// let result: & PyArray2<_> = einsum !("ijk,ji->ik", tensor, another_tensor).unwrap();
257+ /// let result: Bound<'_, PyArray2<_>> = einsum_bound !("ijk,ji->ik", tensor, another_tensor).unwrap();
181258///
182259/// assert_eq!(
183260/// result.readonly().as_array(),
@@ -188,9 +265,9 @@ where
188265///
189266/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
190267#[ macro_export]
191- macro_rules! einsum {
268+ macro_rules! einsum_bound {
192269 ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
193- let arrays = [ $( $array. to_dyn( ) , ) +] ;
194- $crate:: einsum ( concat!( $subscripts, "\0 " ) , & arrays)
270+ let arrays = [ $( $array. to_dyn( ) . as_borrowed ( ) , ) +] ;
271+ $crate:: einsum_bound ( concat!( $subscripts, "\0 " ) , & arrays)
195272 } } ;
196273}
0 commit comments