1- use crate :: npyffi:: { NPY_CASTING , NPY_ORDER } ;
2- use crate :: { Element , PyArray , PY_ARRAY_API } ;
1+ use std:: borrow:: Cow ;
2+ use std:: ffi:: { CStr , CString } ;
3+ use std:: ptr:: null_mut;
4+
35use ndarray:: { Dimension , IxDyn } ;
4- use pyo3:: { AsPyPointer , FromPyPointer , PyAny , PyNativeType , PyResult } ;
5- use std:: ffi:: CStr ;
6+ use pyo3:: { AsPyPointer , FromPyObject , FromPyPointer , PyAny , PyNativeType , PyResult } ;
7+
8+ use crate :: array:: PyArray ;
9+ use crate :: dtype:: Element ;
10+ use crate :: npyffi:: { array:: PY_ARRAY_API , NPY_CASTING , NPY_ORDER } ;
11+
12+ /// Return value of a function that can yield either an array or a scalar.
13+ pub trait ArrayOrScalar < ' py , T > : FromPyObject < ' py > { }
14+
15+ impl < ' py , T , D > ArrayOrScalar < ' py , T > for & ' py PyArray < T , D >
16+ where
17+ T : Element ,
18+ D : Dimension ,
19+ {
20+ }
21+
22+ impl < ' py , T > ArrayOrScalar < ' py , T > for T where T : Element + FromPyObject < ' py > { }
623
724/// Return the inner product of two arrays.
825///
9- /// # Example
26+ /// [NumPy's documentation][inner] has the details.
27+ ///
28+ /// # Examples
29+ ///
30+ /// Note that this function can either return a scalar...
31+ ///
1032/// ```
11- /// pyo3::Python::with_gil(|py| {
12- /// let array = numpy::pyarray![py, 1, 2, 3];
13- /// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14- /// assert_eq!(inner.item(), 14);
33+ /// use pyo3::Python;
34+ /// use numpy::{inner, pyarray, PyArray0};
35+ ///
36+ /// Python::with_gil(|py| {
37+ /// let vector = pyarray![py, 1.0, 2.0, 3.0];
38+ /// let result: f64 = inner(vector, vector).unwrap();
39+ /// assert_eq!(result, 14.0);
1540/// });
1641/// ```
17- pub fn inner < ' py , T , DIN1 , DIN2 , DOUT > (
42+ ///
43+ /// ...or an array depending on its arguments.
44+ ///
45+ /// ```
46+ /// use pyo3::Python;
47+ /// use numpy::{inner, pyarray, PyArray0};
48+ ///
49+ /// Python::with_gil(|py| {
50+ /// let vector = pyarray![py, 1, 2, 3];
51+ /// let result: &PyArray0<_> = inner(vector, vector).unwrap();
52+ /// assert_eq!(result.item(), 14);
53+ /// });
54+ /// ```
55+ ///
56+ /// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
57+ pub fn inner < ' py , T , DIN1 , DIN2 , OUT > (
1858 array1 : & ' py PyArray < T , DIN1 > ,
1959 array2 : & ' py PyArray < T , DIN2 > ,
20- ) -> PyResult < & ' py PyArray < T , DOUT > >
60+ ) -> PyResult < OUT >
2161where
62+ T : Element ,
2263 DIN1 : Dimension ,
2364 DIN2 : Dimension ,
24- DOUT : Dimension ,
25- T : Element ,
65+ OUT : ArrayOrScalar < ' py , T > ,
2666{
2767 let py = array1. py ( ) ;
2868 let obj = unsafe {
@@ -34,27 +74,53 @@ where
3474
3575/// Return the dot product of two arrays.
3676///
37- /// # Example
77+ /// [NumPy's documentation][dot] has the details.
78+ ///
79+ /// # Examples
80+ ///
81+ /// Note that this function can either return an array...
82+ ///
3883/// ```
39- /// pyo3::Python::with_gil(|py| {
40- /// let a = numpy::pyarray![py, [1, 0], [0, 1]];
41- /// let b = numpy::pyarray![py, [4, 1], [2, 2]];
42- /// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
84+ /// use pyo3::Python;
85+ /// use ndarray::array;
86+ /// use numpy::{dot, pyarray, PyArray2};
87+ ///
88+ /// Python::with_gil(|py| {
89+ /// let matrix = pyarray![py, [1, 0], [0, 1]];
90+ /// let another_matrix = pyarray![py, [4, 1], [2, 2]];
91+ ///
92+ /// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
93+ ///
4394/// assert_eq!(
44- /// dot .readonly().as_array(),
45- /// ndarray:: array![[4, 1], [2, 2]]
95+ /// result .readonly().as_array(),
96+ /// array![[4, 1], [2, 2]]
4697/// );
4798/// });
4899/// ```
49- pub fn dot < ' py , T , DIN1 , DIN2 , DOUT > (
100+ ///
101+ /// ...or a scalar depending on its arguments.
102+ ///
103+ /// ```
104+ /// use pyo3::Python;
105+ /// use numpy::{dot, pyarray, PyArray0};
106+ ///
107+ /// Python::with_gil(|py| {
108+ /// let vector = pyarray![py, 1.0, 2.0, 3.0];
109+ /// let result: f64 = dot(vector, vector).unwrap();
110+ /// assert_eq!(result, 14.0);
111+ /// });
112+ /// ```
113+ ///
114+ /// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
115+ pub fn dot < ' py , T , DIN1 , DIN2 , OUT > (
50116 array1 : & ' py PyArray < T , DIN1 > ,
51117 array2 : & ' py PyArray < T , DIN2 > ,
52- ) -> PyResult < & ' py PyArray < T , DOUT > >
118+ ) -> PyResult < OUT >
53119where
120+ T : Element ,
54121 DIN1 : Dimension ,
55122 DIN2 : Dimension ,
56- DOUT : Dimension ,
57- T : Element ,
123+ OUT : ArrayOrScalar < ' py , T > ,
58124{
59125 let py = array1. py ( ) ;
60126 let obj = unsafe {
@@ -66,31 +132,28 @@ where
66132
67133/// Return the Einstein summation convention of given tensors.
68134///
69- /// We also provide the [einsum macro](./macro.einsum.html).
70- pub fn einsum_impl < ' py , T , DOUT > (
71- subscripts : & str ,
72- arrays : & [ & ' py PyArray < T , IxDyn > ] ,
73- ) -> PyResult < & ' py PyArray < T , DOUT > >
135+ /// This is usually invoked via the the [`einsum!`] macro.
136+ pub fn einsum < ' py , T , OUT > ( subscripts : & str , arrays : & [ & ' py PyArray < T , IxDyn > ] ) -> PyResult < OUT >
74137where
75- DOUT : Dimension ,
76138 T : Element ,
139+ OUT : ArrayOrScalar < ' py , T > ,
77140{
78- let subscripts: std:: borrow:: Cow < CStr > = match CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) )
79- {
80- Ok ( subscripts) => subscripts. into ( ) ,
81- Err ( _) => std:: ffi:: CString :: new ( subscripts) . unwrap ( ) . into ( ) ,
141+ let subscripts = match CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) ) {
142+ Ok ( subscripts) => Cow :: Borrowed ( subscripts) ,
143+ Err ( _) => Cow :: Owned ( CString :: new ( subscripts) . unwrap ( ) ) ,
82144 } ;
145+
83146 let py = arrays[ 0 ] . py ( ) ;
84147 let obj = unsafe {
85148 let result = PY_ARRAY_API . PyArray_EinsteinSum (
86149 py,
87150 subscripts. as_ptr ( ) as _ ,
88151 arrays. len ( ) as _ ,
89152 arrays. as_ptr ( ) as _ ,
90- std :: ptr :: null_mut ( ) ,
153+ null_mut ( ) ,
91154 NPY_ORDER :: NPY_KEEPORDER ,
92155 NPY_CASTING :: NPY_NO_CASTING ,
93- std :: ptr :: null_mut ( ) ,
156+ null_mut ( ) ,
94157 ) ;
95158 PyAny :: from_owned_ptr_or_err ( py, result) ?
96159 } ;
@@ -99,25 +162,34 @@ where
99162
100163/// Return the Einstein summation convention of given tensors.
101164///
102- /// For more about the Einstein summation convention, you may reffer to
103- /// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy. einsum.html) .
165+ /// For more about the Einstein summation convention, please refer to
166+ /// [NumPy's documentation][ einsum] .
104167///
105168/// # Example
169+ ///
106170/// ```
107- /// pyo3::Python::with_gil(|py| {
108- /// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
109- /// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
110- /// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
171+ /// use pyo3::Python;
172+ /// use ndarray::array;
173+ /// use numpy::{einsum, pyarray, PyArray, PyArray2};
174+ ///
175+ /// Python::with_gil(|py| {
176+ /// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
177+ /// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
178+ ///
179+ /// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
180+ ///
111181/// assert_eq!(
112- /// einsum .readonly().as_array(),
113- /// ndarray:: array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
182+ /// result .readonly().as_array(),
183+ /// array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
114184/// );
115185/// });
116186/// ```
187+ ///
188+ /// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
117189#[ macro_export]
118190macro_rules! einsum {
119- ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
191+ ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
120192 let arrays = [ $( $array. to_dyn( ) , ) +] ;
121- unsafe { $crate:: einsum_impl ( concat!( $subscripts, "\0 " ) , & arrays) }
193+ $crate:: einsum ( concat!( $subscripts, "\0 " ) , & arrays)
122194 } } ;
123195}
0 commit comments