Skip to content

Commit 0de29c5

Browse files
authored
Merge pull request numpy#19083 from hameerabbasi/dlpack
ENH: Implement the DLPack Array API protocols for ndarray.
2 parents c280e21 + 5b94a03 commit 0de29c5

File tree

16 files changed

+830
-20
lines changed

16 files changed

+830
-20
lines changed

doc/neps/nep-0047-array-api-standard.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,10 @@ the options already present in NumPy are:
338338

339339
Adding support for DLPack to NumPy entails:
340340

341-
- Adding a ``ndarray.__dlpack__`` method.
342-
- Adding a ``from_dlpack`` function, which takes as input an object
343-
supporting ``__dlpack__``, and returns an ``ndarray``.
341+
- Adding a ``ndarray.__dlpack__()`` method which returns a ``dlpack`` C
342+
structure wrapped in a ``PyCapsule``.
343+
- Adding a ``np._from_dlpack(obj)`` function, where ``obj`` supports
344+
``__dlpack__()``, and returns an ``ndarray``.
344345

345346
DLPack is currently a ~200 LoC header, and is meant to be included directly, so
346347
no external dependency is needed. Implementation should be straightforward.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Add NEP 47-compatible dlpack support
2+
------------------------------------
3+
4+
Add a ``ndarray.__dlpack__()`` method which returns a ``dlpack`` C structure
5+
wrapped in a ``PyCapsule``. Also add a ``np._from_dlpack(obj)`` function, where
6+
``obj`` supports ``__dlpack__()``, and returns an ``ndarray``.

numpy/__init__.pyi

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,7 @@ _SupportsBuffer = Union[
14131413

14141414
_T = TypeVar("_T")
14151415
_T_co = TypeVar("_T_co", covariant=True)
1416+
_T_contra = TypeVar("_T_contra", contravariant=True)
14161417
_2Tuple = Tuple[_T, _T]
14171418
_CastingKind = L["no", "equiv", "safe", "same_kind", "unsafe"]
14181419

@@ -1432,6 +1433,10 @@ _ArrayTD64_co = NDArray[Union[bool_, integer[Any], timedelta64]]
14321433
# Introduce an alias for `dtype` to avoid naming conflicts.
14331434
_dtype = dtype
14341435

1436+
# `builtins.PyCapsule` unfortunately lacks annotations as of the moment;
1437+
# use `Any` as a stopgap measure
1438+
_PyCapsule = Any
1439+
14351440
class _SupportsItem(Protocol[_T_co]):
14361441
def item(self, args: Any, /) -> _T_co: ...
14371442

@@ -2439,6 +2444,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
24392444
def __ior__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ...
24402445
@overload
24412446
def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
2447+
@overload
2448+
def __ior__(self: NDArray[_ScalarType], other: _RecursiveSequence) -> NDArray[_ScalarType]: ...
2449+
@overload
2450+
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
2451+
@overload
2452+
def __dlpack_device__(self) -> Tuple[int, L[0]]: ...
24422453

24432454
# Keep `dtype` at the bottom to avoid name conflicts with `np.dtype`
24442455
@property
@@ -4320,3 +4331,9 @@ class chararray(ndarray[_ShapeType, _CharDType]):
43204331

43214332
# NOTE: Deprecated
43224333
# class MachAr: ...
4334+
4335+
class _SupportsDLPack(Protocol[_T_contra]):
4336+
def __dlpack__(self, *, stream: None | _T_contra = ...) -> _PyCapsule: ...
4337+
4338+
def _from_dlpack(__obj: _SupportsDLPack[None]) -> NDArray[Any]: ...
4339+

numpy/array_api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
empty,
137137
empty_like,
138138
eye,
139-
from_dlpack,
139+
_from_dlpack,
140140
full,
141141
full_like,
142142
linspace,
@@ -155,7 +155,7 @@
155155
"empty",
156156
"empty_like",
157157
"eye",
158-
"from_dlpack",
158+
"_from_dlpack",
159159
"full",
160160
"full_like",
161161
"linspace",

numpy/array_api/_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def eye(
151151
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
152152

153153

154-
def from_dlpack(x: object, /) -> Array:
154+
def _from_dlpack(x: object, /) -> Array:
155155
# Note: dlpack support is not yet implemented on Array
156156
raise NotImplementedError("DLPack support is not yet implemented")
157157

numpy/core/_add_newdocs.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,19 @@
15731573
array_function_like_doc,
15741574
))
15751575

1576+
add_newdoc('numpy.core.multiarray', '_from_dlpack',
1577+
"""
1578+
_from_dlpack(x, /)
1579+
1580+
Create a NumPy array from an object implementing the ``__dlpack__``
1581+
protocol.
1582+
1583+
See Also
1584+
--------
1585+
`Array API documentation
1586+
<https://data-apis.org/array-api/latest/design_topics/data_interchange.html#syntax-for-data-interchange-with-dlpack>`_
1587+
""")
1588+
15761589
add_newdoc('numpy.core', 'fastCopyAndTranspose',
15771590
"""_fastCopyAndTranspose(a)""")
15781591

@@ -2263,6 +2276,15 @@
22632276
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_struct__',
22642277
"""Array protocol: C-struct side."""))
22652278

2279+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__dlpack__',
2280+
"""a.__dlpack__(*, stream=None)
2281+
2282+
DLPack Protocol: Part of the Array API."""))
2283+
2284+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__dlpack_device__',
2285+
"""a.__dlpack_device__()
2286+
2287+
DLPack Protocol: Part of the Array API."""))
22662288

22672289
add_newdoc('numpy.core.multiarray', 'ndarray', ('base',
22682290
"""

numpy/core/code_generators/genapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
join('multiarray', 'datetime_busdaycal.c'),
4242
join('multiarray', 'datetime_strings.c'),
4343
join('multiarray', 'descriptor.c'),
44+
join('multiarray', 'dlpack.c'),
4445
join('multiarray', 'dtypemeta.c'),
4546
join('multiarray', 'einsum.c.src'),
4647
join('multiarray', 'flagsobject.c'),

numpy/core/multiarray.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,28 @@
1414
# do not change them. issue gh-15518
1515
# _get_ndarray_c_version is semi-public, on purpose not added to __all__
1616
from ._multiarray_umath import (
17-
_fastCopyAndTranspose, _flagdict, _insert, _reconstruct, _vec_string,
18-
_ARRAY_API, _monotonicity, _get_ndarray_c_version, _set_madvise_hugepage,
17+
_fastCopyAndTranspose, _flagdict, _from_dlpack, _insert, _reconstruct,
18+
_vec_string, _ARRAY_API, _monotonicity, _get_ndarray_c_version,
19+
_set_madvise_hugepage,
1920
)
2021

2122
__all__ = [
2223
'_ARRAY_API', 'ALLOW_THREADS', 'BUFSIZE', 'CLIP', 'DATETIMEUNITS',
2324
'ITEM_HASOBJECT', 'ITEM_IS_POINTER', 'LIST_PICKLE', 'MAXDIMS',
2425
'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'NEEDS_INIT', 'NEEDS_PYAPI',
2526
'RAISE', 'USE_GETITEM', 'USE_SETITEM', 'WRAP', '_fastCopyAndTranspose',
26-
'_flagdict', '_insert', '_reconstruct', '_vec_string', '_monotonicity',
27-
'add_docstring', 'arange', 'array', 'asarray', 'asanyarray',
28-
'ascontiguousarray', 'asfortranarray', 'bincount', 'broadcast',
29-
'busday_count', 'busday_offset', 'busdaycalendar', 'can_cast',
27+
'_flagdict', '_from_dlpack', '_insert', '_reconstruct', '_vec_string',
28+
'_monotonicity', 'add_docstring', 'arange', 'array', 'asarray',
29+
'asanyarray', 'ascontiguousarray', 'asfortranarray', 'bincount',
30+
'broadcast', 'busday_count', 'busday_offset', 'busdaycalendar', 'can_cast',
3031
'compare_chararrays', 'concatenate', 'copyto', 'correlate', 'correlate2',
3132
'count_nonzero', 'c_einsum', 'datetime_as_string', 'datetime_data',
3233
'dot', 'dragon4_positional', 'dragon4_scientific', 'dtype',
3334
'empty', 'empty_like', 'error', 'flagsobj', 'flatiter', 'format_longfloat',
34-
'frombuffer', 'fromfile', 'fromiter', 'fromstring', 'get_handler_name',
35-
'inner', 'interp', 'interp_complex', 'is_busday', 'lexsort',
36-
'matmul', 'may_share_memory', 'min_scalar_type', 'ndarray', 'nditer',
37-
'nested_iters', 'normalize_axis_index', 'packbits',
35+
'frombuffer', 'fromfile', 'fromiter', 'fromstring',
36+
'get_handler_name', 'inner', 'interp', 'interp_complex', 'is_busday',
37+
'lexsort', 'matmul', 'may_share_memory', 'min_scalar_type', 'ndarray',
38+
'nditer', 'nested_iters', 'normalize_axis_index', 'packbits',
3839
'promote_types', 'putmask', 'ravel_multi_index', 'result_type', 'scalar',
3940
'set_datetimeparse_function', 'set_legacy_print_mode', 'set_numeric_ops',
4041
'set_string_function', 'set_typeDict', 'shares_memory',
@@ -46,6 +47,7 @@
4647
scalar.__module__ = 'numpy.core.multiarray'
4748

4849

50+
_from_dlpack.__module__ = 'numpy'
4951
arange.__module__ = 'numpy'
5052
array.__module__ = 'numpy'
5153
asarray.__module__ = 'numpy'

numpy/core/numeric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
WRAP, arange, array, asarray, asanyarray, ascontiguousarray,
1414
asfortranarray, broadcast, can_cast, compare_chararrays,
1515
concatenate, copyto, dot, dtype, empty,
16-
empty_like, flatiter, frombuffer, fromfile, fromiter, fromstring,
17-
inner, lexsort, matmul, may_share_memory,
16+
empty_like, flatiter, frombuffer, _from_dlpack, fromfile, fromiter,
17+
fromstring, inner, lexsort, matmul, may_share_memory,
1818
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
1919
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
2020
zeros, normalize_axis_index)
@@ -41,7 +41,7 @@
4141
'newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc',
4242
'arange', 'array', 'asarray', 'asanyarray', 'ascontiguousarray',
4343
'asfortranarray', 'zeros', 'count_nonzero', 'empty', 'broadcast', 'dtype',
44-
'fromstring', 'fromfile', 'frombuffer', 'where',
44+
'fromstring', 'fromfile', 'frombuffer', '_from_dlpack', 'where',
4545
'argwhere', 'copyto', 'concatenate', 'fastCopyAndTranspose', 'lexsort',
4646
'set_numeric_ops', 'can_cast', 'promote_types', 'min_scalar_type',
4747
'result_type', 'isfortran', 'empty_like', 'zeros_like', 'ones_like',

numpy/core/setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ def gl_if_msvc(build_cmd):
740740
#######################################################################
741741

742742
common_deps = [
743+
join('src', 'common', 'dlpack', 'dlpack.h'),
743744
join('src', 'common', 'array_assign.h'),
744745
join('src', 'common', 'binop_override.h'),
745746
join('src', 'common', 'cblasfuncs.h'),
@@ -749,6 +750,7 @@ def gl_if_msvc(build_cmd):
749750
join('src', 'common', 'npy_cblas.h'),
750751
join('src', 'common', 'npy_config.h'),
751752
join('src', 'common', 'npy_ctypes.h'),
753+
join('src', 'common', 'npy_dlpack.h'),
752754
join('src', 'common', 'npy_extint128.h'),
753755
join('src', 'common', 'npy_import.h'),
754756
join('src', 'common', 'npy_hashtable.h'),
@@ -881,6 +883,7 @@ def gl_if_msvc(build_cmd):
881883
join('src', 'multiarray', 'datetime_busday.c'),
882884
join('src', 'multiarray', 'datetime_busdaycal.c'),
883885
join('src', 'multiarray', 'descriptor.c'),
886+
join('src', 'multiarray', 'dlpack.c'),
884887
join('src', 'multiarray', 'dtypemeta.c'),
885888
join('src', 'multiarray', 'dragon4.c'),
886889
join('src', 'multiarray', 'dtype_transfer.c'),

0 commit comments

Comments
 (0)