@@ -6,6 +6,7 @@
 #pragma once
 
 #include "pyAMReX.H"
+#include "dlpack.h"
 
 #include <AMReX_Array4.H>
 #include <AMReX_BLassert.H>
@@ -18,15 +19,6 @@
 #include <sstream>
 #include <type_traits>
 #include <vector>
-#include "dlpack.h"
-
-// GPU backend headers for device detection
-#ifdef AMREX_USE_CUDA
-#include <cuda_runtime.h>
-#endif
-#ifdef AMREX_USE_HIP
-#include <hip/hip_runtime.h>
-#endif
 
 
 namespace
@@ -194,6 +186,7 @@ namespace pyAMReX
         */
 
 
+        /*
         // CPU: __array_interface__ v3
         // https://numpy.org/doc/stable/reference/arrays.interface.html
         .def_property_readonly("__array_interface__", [](Array4<T> const & a4) {
@@ -229,60 +222,26 @@ namespace pyAMReX
229222 d["version"] = 3;
230223 return d;
231224 })
225+ */
232226
233227
234228 // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
235229 // https://dmlc.github.io/dlpack/latest/
236230 // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
237231 // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
238- .def (" __dlpack__" , [](Array4<T> const &a4, py::handle stream = py::none ()) {
232+ .def (" __dlpack__" , [](Array4<T> const &a4, [[maybe_unused]] py::handle stream = py::none ()) {
239233 // Allocate shape/strides arrays
240234 constexpr int ndim = 4 ;
241235 auto const len = length (a4);
242236 auto *shape = new int64_t [ndim]{a4.nComp (), len.z , len.y , len.x };
243237 auto *strides = new int64_t [ndim]{a4.nstride , a4.kstride , a4.jstride , 1 };
244- // DLPack dtype
245- DLDataType dtype{};
246- if constexpr (std::is_same_v<T, float >) { dtype.code = kDLFloat ; dtype.bits = 32 ; dtype.lanes = 1 ; }
247- else if constexpr (std::is_same_v<T, double >) { dtype.code = kDLFloat ; dtype.bits = 64 ; dtype.lanes = 1 ; }
248- else if constexpr (std::is_same_v<T, int32_t >) { dtype.code = kDLInt ; dtype.bits = 32 ; dtype.lanes = 1 ; }
249- else if constexpr (std::is_same_v<T, int64_t >) { dtype.code = kDLInt ; dtype.bits = 64 ; dtype.lanes = 1 ; }
250- else if constexpr (std::is_same_v<T, uint32_t >) { dtype.code = kDLUInt ; dtype.bits = 32 ; dtype.lanes = 1 ; }
251- else if constexpr (std::is_same_v<T, uint64_t >) { dtype.code = kDLUInt ; dtype.bits = 64 ; dtype.lanes = 1 ; }
252- else { throw std::runtime_error (" Unsupported dtype for DLPack" ); }
253-
254- // Device detection based on AMReX GPU backend
255- DLDevice device{ kDLCPU , 0 };
256- #ifdef AMREX_USE_CUDA
257- // Check if data is on GPU by checking if pointer is in CUDA memory
258- cudaPointerAttributes attr;
259- cudaError_t err = cudaPointerGetAttributes (&attr, a4.dataPtr ());
260- if (err == cudaSuccess && attr.memoryType == cudaMemoryTypeDevice) {
261- device.device_type = kDLCUDA ;
262- device.device_id = attr.device ;
263- }
264- #elif defined(AMREX_USE_HIP)
265- // Check if data is on GPU by checking if pointer is in HIP memory
266- hipPointerAttribute_t attr;
267- hipError_t err = hipPointerGetAttributes (&attr, a4.dataPtr ());
268- if (err == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
269- device.device_type = kDLROCM ;
270- device.device_id = attr.device ;
271- }
272- #elif defined(AMREX_USE_DPCPP)
273- // For SYCL, we need to check if the data is on device
274- // This is more complex as SYCL doesn't have a simple pointer check
275- // For now, assume CPU - SYCL support would need more sophisticated detection
276- // device.device_type = kDLExtDev; // SYCL would use extended device type
277- // device.device_id = 0;
278- #endif
279238
280239 // Construct DLTensor
281240 auto *dl_tensor = new DLManagedTensor;
282241 dl_tensor->dl_tensor .data = const_cast <void *>(static_cast <const void *>(a4.dataPtr ()));
283- dl_tensor->dl_tensor .device = device ;
242+ dl_tensor->dl_tensor .device = dlpack::detect_device_from_pointer (a4. dataPtr ()) ;
284243 dl_tensor->dl_tensor .ndim = ndim;
285- dl_tensor->dl_tensor .dtype = dtype ;
244+ dl_tensor->dl_tensor .dtype = dlpack::get_dlpack_dtype<T>() ;
286245 dl_tensor->dl_tensor .shape = shape;
287246 dl_tensor->dl_tensor .strides = strides;
288247 dl_tensor->dl_tensor .byte_offset = 0 ;
@@ -297,40 +256,16 @@ namespace pyAMReX
                 auto *tensor = static_cast<DLManagedTensor*>(ptr);
                 tensor->deleter(tensor);
             });
-        }, py::arg("stream") = py::none(), R"doc(
+        },
+        py::arg("stream") = py::none(),
+        R"doc(
             DLPack protocol for zero-copy tensor exchange.
             See https://dmlc.github.io/dlpack/latest/ for details.
-        )doc")
+        )doc"
+        )
304266 .def (" __dlpack_device__" , [](Array4<T> const &a4) {
305- // Device detection based on AMReX GPU backend
306- int device_type = kDLCPU ;
307- int device_id = 0 ;
308-
309- #ifdef AMREX_USE_CUDA
310- // Check if data is on GPU by checking if pointer is in CUDA memory
311- cudaPointerAttributes attr;
312- cudaError_t err = cudaPointerGetAttributes (&attr, a4.dataPtr ());
313- if (err == cudaSuccess && attr.memoryType == cudaMemoryTypeDevice) {
314- device_type = kDLCUDA ;
315- device_id = attr.device ;
316- }
317- #elif defined(AMREX_USE_HIP)
318- // Check if data is on GPU by checking if pointer is in HIP memory
319- hipPointerAttribute_t attr;
320- hipError_t err = hipPointerGetAttributes (&attr, a4.dataPtr ());
321- if (err == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
322- device_type = kDLROCM ;
323- device_id = attr.device ;
324- }
325- #elif defined(AMREX_USE_DPCPP)
326- // For SYCL, we need to check if the data is on device
327- // This is more complex as SYCL doesn't have a simple pointer check
328- // For now, assume CPU - SYCL support would need more sophisticated detection
329- // device_type = kDLExtDev; // SYCL would use extended device type
330- // device_id = 0;
331- #endif
332-
333- return std::make_tuple (device_type, device_id);
267+ DLDevice device = dlpack::detect_device_from_pointer (a4.dataPtr ());
268+ return std::make_tuple (device.device_type , device.device_id );
334269 }, R"doc(
335270 DLPack device info (device_type, device_id).
336271 )doc" )
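One piece of the ownership contract sits outside the visible hunks: `dl_tensor->deleter` must free the `new`-allocated shape/strides arrays and the `DLManagedTensor` wrapper, but not the `Array4` data it views, since the capsule destructor above only forwards to `tensor->deleter(tensor)`. Below is a self-contained sketch of that same contract over plain host data; `make_view` is a hypothetical helper for illustration only, not part of pyAMReX.

```cpp
#include "dlpack/dlpack.h"  // vendored DLPack header; path is an assumption

#include <cstdint>
#include <vector>

// Build a non-owning DLManagedTensor view over existing host data, mirroring
// the ownership contract used in the binding above: shape/strides are
// heap-allocated, the deleter frees them plus the wrapper, and the underlying
// buffer is left alone.
DLManagedTensor* make_view (double* data, std::vector<int64_t> shape_in)
{
    auto ndim = static_cast<int32_t>(shape_in.size());
    auto* shape = new int64_t[ndim];
    auto* strides = new int64_t[ndim];
    int64_t stride = 1;
    for (int i = ndim - 1; i >= 0; --i) {
        shape[i] = shape_in[i];
        strides[i] = stride;  // C-contiguous strides, in elements
        stride *= shape_in[i];
    }

    auto* t = new DLManagedTensor{};
    t->dl_tensor.data = data;
    t->dl_tensor.device = DLDevice{kDLCPU, 0};
    t->dl_tensor.ndim = ndim;
    t->dl_tensor.dtype = DLDataType{kDLFloat, 64, 1};
    t->dl_tensor.shape = shape;
    t->dl_tensor.strides = strides;
    t->dl_tensor.byte_offset = 0;
    t->manager_ctx = nullptr;
    t->deleter = [](DLManagedTensor* self) {
        delete[] self->dl_tensor.shape;
        delete[] self->dl_tensor.strides;
        delete self;  // the data buffer is not owned, so it is not freed
    };
    return t;
}

int main ()
{
    std::vector<double> buf(24, 0.0);
    auto* t = make_view(buf.data(), {2, 3, 4});
    // A consumer that takes the capsule eventually invokes the deleter:
    t->deleter(t);
}
```

A consumer that imports the capsule (e.g. via `numpy.from_dlpack` or `cupy.from_dlpack`) takes over the tensor and is responsible for invoking the deleter exactly once.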