@@ -23,6 +23,17 @@ struct ndarray_handle {
2323 bool ro;
2424};
2525
26+ static void ndarray_capsule_destructor (PyObject *o) {
27+ error_scope scope; // temporarily save any existing errors
28+ managed_dltensor *mt =
29+ (managed_dltensor *) PyCapsule_GetPointer (o, " dltensor" );
30+
31+ if (mt)
32+ ndarray_dec_ref ((ndarray_handle *) mt->manager_ctx );
33+ else
34+ PyErr_Clear ();
35+ }
36+
2637static void nb_ndarray_dealloc (PyObject *self) {
2738 PyTypeObject *tp = Py_TYPE (self);
2839 ndarray_dec_ref (((nb_ndarray *) self)->th );
@@ -123,12 +134,52 @@ static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) {
123134 PyMem_Free (view->strides );
124135}
125136
137+
138+ static PyObject *nb_ndarray_dlpack (PyObject *self, PyTypeObject *,
139+ PyObject *const *, Py_ssize_t ,
140+ PyObject *) {
141+ nb_ndarray *self_nd = (nb_ndarray *) self;
142+ ndarray_handle *th = self_nd->th ;
143+
144+ PyObject *r =
145+ PyCapsule_New (th->ndarray , " dltensor" , ndarray_capsule_destructor);
146+ if (r)
147+ ndarray_inc_ref (th);
148+ return r;
149+ }
150+
151+ static PyObject *nb_ndarray_dlpack_device (PyObject *self, PyTypeObject *,
152+ PyObject *const *, Py_ssize_t ,
153+ PyObject *) {
154+ nb_ndarray *self_nd = (nb_ndarray *) self;
155+ dlpack::dltensor &t = self_nd->th ->ndarray ->dltensor ;
156+ PyObject *r = PyTuple_New (2 );
157+ PyObject *r0 = PyLong_FromLong (t.device .device_type );
158+ PyObject *r1 = PyLong_FromLong (t.device .device_id );
159+ if (!r || !r0 || !r1) {
160+ Py_XDECREF (r);
161+ Py_XDECREF (r0);
162+ Py_XDECREF (r1);
163+ return nullptr ;
164+ }
165+ NB_TUPLE_SET_ITEM (r, 0 , r0);
166+ NB_TUPLE_SET_ITEM (r, 1 , r1);
167+ return r;
168+ }
169+
170+ static PyMethodDef nb_ndarray_members[] = {
171+ { " __dlpack__" , (PyCFunction) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr },
172+ { " __dlpack_device__" , (PyCFunction) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr },
173+ { nullptr , nullptr , 0 , nullptr }
174+ };
175+
126176static PyTypeObject *nd_ndarray_tp () noexcept {
127177 PyTypeObject *tp = internals->nb_ndarray ;
128178
129179 if (NB_UNLIKELY (!tp)) {
130180 PyType_Slot slots[] = {
131181 { Py_tp_dealloc, (void *) nb_ndarray_dealloc },
182+ { Py_tp_methods, (void *) nb_ndarray_members },
132183#if PY_VERSION_HEX >= 0x03090000
133184 { Py_bf_getbuffer, (void *) nd_ndarray_tpbuffer },
134185 { Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer },
@@ -649,17 +700,6 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
649700 return result.release ();
650701}
651702
652- static void ndarray_capsule_destructor (PyObject *o) {
653- error_scope scope; // temporarily save any existing errors
654- managed_dltensor *mt =
655- (managed_dltensor *) PyCapsule_GetPointer (o, " dltensor" );
656-
657- if (mt)
658- ndarray_dec_ref ((ndarray_handle *) mt->manager_ctx );
659- else
660- PyErr_Clear ();
661- }
662-
663703PyObject *ndarray_export (ndarray_handle *th, int framework,
664704 rv_policy policy, cleanup_list *cleanup) noexcept {
665705 if (!th)
@@ -706,79 +746,47 @@ PyObject *ndarray_export(ndarray_handle *th, int framework,
706746 }
707747 }
708748
709- if (framework == numpy::value) {
710- try {
711- nb_ndarray *h = PyObject_New (nb_ndarray, nd_ndarray_tp ());
712- if (!h)
713- return nullptr ;
714- h->th = th;
715- ndarray_inc_ref (th);
716-
717- object o = steal ((PyObject *) h);
718- return module_::import_ (" numpy" )
719- .attr (" array" )(o, arg (" copy" ) = copy)
720- .release ()
721- .ptr ();
722- } catch (const std::exception &e) {
723- PyErr_Format (PyExc_RuntimeError,
724- " nanobind::detail::ndarray_export(): could not "
725- " convert ndarray to NumPy array: %s" , e.what ());
726- return nullptr ;
727- }
728- }
729-
730- object package;
731- try {
732- switch (framework) {
733- case no_framework::value:
734- break ;
735-
736- case pytorch::value:
737- package = module_::import_ (" torch.utils.dlpack" );
738- break ;
739-
740- case tensorflow::value:
741- package = module_::import_ (" tensorflow.experimental.dlpack" );
742- break ;
743-
744- case jax::value:
745- package = module_::import_ (" jax.dlpack" );
746- break ;
747-
748- case cupy::value:
749- package = module_::import_ (" cupy" );
750- break ;
751-
752- default :
753- check (false , " nanobind::detail::ndarray_export(): unknown "
754- " framework specified!" );
755- }
756- } catch (const std::exception &e) {
757- PyErr_Format (PyExc_RuntimeError,
758- " nanobind::detail::ndarray_export(): could not import ndarray "
759- " framework: %s" , e.what ());
760- return nullptr ;
761- }
762-
763749 object o;
764750 if (copy && framework == no_framework::value && th->self ) {
765751 o = borrow (th->self );
752+ } else if (framework == numpy::value || framework == jax::value) {
753+ nb_ndarray *h = PyObject_New (nb_ndarray, nd_ndarray_tp ());
754+ if (!h)
755+ return nullptr ;
756+ h->th = th;
757+ ndarray_inc_ref (th);
758+ o = steal ((PyObject *) h);
766759 } else {
767760 o = steal (PyCapsule_New (th->ndarray , " dltensor" ,
768761 ndarray_capsule_destructor));
769762 ndarray_inc_ref (th);
770763 }
771764
765+ try {
766+ if (framework == numpy::value) {
767+ return module_::import_ (" numpy" )
768+ .attr (" array" )(o, arg (" copy" ) = copy)
769+ .release ()
770+ .ptr ();
771+ } else {
772+ const char *pkg_name;
773+ switch (framework) {
774+ case pytorch::value: pkg_name = " torch.utils.dlpack" ; break ;
775+ case tensorflow::value: pkg_name = " tensorflow.experimental.dlpack" ; break ;
776+ case jax::value: pkg_name = " jax.dlpack" ; break ;
777+ case cupy::value: pkg_name = " cupy" ; break ;
778+ default : pkg_name = nullptr ;
779+ }
772780
773- if (package.is_valid ()) {
774- try {
775- o = package.attr (" from_dlpack" )(o);
776- } catch (const std::exception &e) {
777- PyErr_Format (PyExc_RuntimeError,
778- " nanobind::detail::ndarray_export(): could not "
779- " import ndarray: %s" , e.what ());
780- return nullptr ;
781+ if (pkg_name)
782+ o = module_::import_ (pkg_name).attr (" from_dlpack" )(o);
781783 }
784+ } catch (const std::exception &e) {
785+ PyErr_Format (PyExc_RuntimeError,
786+ " nanobind::detail::ndarray_export(): could not "
787+ " import ndarray: %s" ,
788+ e.what ());
789+ return nullptr ;
782790 }
783791
784792 if (copy) {
0 commit comments