Skip to content

Commit 4539006

Browse files
committed
Cleaning
1 parent cbac976 commit 4539006

File tree

3 files changed

+479
-144
lines changed

3 files changed

+479
-144
lines changed

src/Base/Array4.H

Lines changed: 13 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#pragma once
77

88
#include "pyAMReX.H"
9+
#include "dlpack.h"
910

1011
#include <AMReX_Array4.H>
1112
#include <AMReX_BLassert.H>
@@ -18,15 +19,6 @@
1819
#include <sstream>
1920
#include <type_traits>
2021
#include <vector>
21-
#include "dlpack.h"
22-
23-
// GPU backend headers for device detection
24-
#ifdef AMREX_USE_CUDA
25-
#include <cuda_runtime.h>
26-
#endif
27-
#ifdef AMREX_USE_HIP
28-
#include <hip/hip_runtime.h>
29-
#endif
3022

3123

3224
namespace
@@ -194,6 +186,7 @@ namespace pyAMReX
194186
*/
195187

196188

189+
/*
197190
// CPU: __array_interface__ v3
198191
// https://numpy.org/doc/stable/reference/arrays.interface.html
199192
.def_property_readonly("__array_interface__", [](Array4<T> const & a4) {
@@ -229,60 +222,26 @@ namespace pyAMReX
229222
d["version"] = 3;
230223
return d;
231224
})
225+
*/
232226

233227

234228
// DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
235229
// https://dmlc.github.io/dlpack/latest/
236230
// https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
237231
// https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
238-
.def("__dlpack__", [](Array4<T> const &a4, py::handle stream = py::none()) {
232+
.def("__dlpack__", [](Array4<T> const &a4, [[maybe_unused]] py::handle stream = py::none()) {
239233
// Allocate shape/strides arrays
240234
constexpr int ndim = 4;
241235
auto const len = length(a4);
242236
auto *shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x};
243237
auto *strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1};
244-
// DLPack dtype
245-
DLDataType dtype{};
246-
if constexpr (std::is_same_v<T, float>) { dtype.code = kDLFloat; dtype.bits = 32; dtype.lanes = 1; }
247-
else if constexpr (std::is_same_v<T, double>) { dtype.code = kDLFloat; dtype.bits = 64; dtype.lanes = 1; }
248-
else if constexpr (std::is_same_v<T, int32_t>) { dtype.code = kDLInt; dtype.bits = 32; dtype.lanes = 1; }
249-
else if constexpr (std::is_same_v<T, int64_t>) { dtype.code = kDLInt; dtype.bits = 64; dtype.lanes = 1; }
250-
else if constexpr (std::is_same_v<T, uint32_t>) { dtype.code = kDLUInt; dtype.bits = 32; dtype.lanes = 1; }
251-
else if constexpr (std::is_same_v<T, uint64_t>) { dtype.code = kDLUInt; dtype.bits = 64; dtype.lanes = 1; }
252-
else { throw std::runtime_error("Unsupported dtype for DLPack"); }
253-
254-
// Device detection based on AMReX GPU backend
255-
DLDevice device{ kDLCPU, 0 };
256-
#ifdef AMREX_USE_CUDA
257-
// Check if data is on GPU by checking if pointer is in CUDA memory
258-
cudaPointerAttributes attr;
259-
cudaError_t err = cudaPointerGetAttributes(&attr, a4.dataPtr());
260-
if (err == cudaSuccess && attr.memoryType == cudaMemoryTypeDevice) {
261-
device.device_type = kDLCUDA;
262-
device.device_id = attr.device;
263-
}
264-
#elif defined(AMREX_USE_HIP)
265-
// Check if data is on GPU by checking if pointer is in HIP memory
266-
hipPointerAttribute_t attr;
267-
hipError_t err = hipPointerGetAttributes(&attr, a4.dataPtr());
268-
if (err == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
269-
device.device_type = kDLROCM;
270-
device.device_id = attr.device;
271-
}
272-
#elif defined(AMREX_USE_DPCPP)
273-
// For SYCL, we need to check if the data is on device
274-
// This is more complex as SYCL doesn't have a simple pointer check
275-
// For now, assume CPU - SYCL support would need more sophisticated detection
276-
// device.device_type = kDLExtDev; // SYCL would use extended device type
277-
// device.device_id = 0;
278-
#endif
279238

280239
// Construct DLTensor
281240
auto *dl_tensor = new DLManagedTensor;
282241
dl_tensor->dl_tensor.data = const_cast<void*>(static_cast<const void*>(a4.dataPtr()));
283-
dl_tensor->dl_tensor.device = device;
242+
dl_tensor->dl_tensor.device = dlpack::detect_device_from_pointer(a4.dataPtr());
284243
dl_tensor->dl_tensor.ndim = ndim;
285-
dl_tensor->dl_tensor.dtype = dtype;
244+
dl_tensor->dl_tensor.dtype = dlpack::get_dlpack_dtype<T>();
286245
dl_tensor->dl_tensor.shape = shape;
287246
dl_tensor->dl_tensor.strides = strides;
288247
dl_tensor->dl_tensor.byte_offset = 0;
@@ -297,40 +256,16 @@ namespace pyAMReX
297256
auto* tensor = static_cast<DLManagedTensor*>(ptr);
298257
tensor->deleter(tensor);
299258
});
300-
}, py::arg("stream") = py::none(), R"doc(
259+
},
260+
py::arg("stream") = py::none(),
261+
R"doc(
301262
DLPack protocol for zero-copy tensor exchange.
302263
See https://dmlc.github.io/dlpack/latest/ for details.
303-
)doc")
264+
)doc"
265+
)
304266
.def("__dlpack_device__", [](Array4<T> const &a4) {
305-
// Device detection based on AMReX GPU backend
306-
int device_type = kDLCPU;
307-
int device_id = 0;
308-
309-
#ifdef AMREX_USE_CUDA
310-
// Check if data is on GPU by checking if pointer is in CUDA memory
311-
cudaPointerAttributes attr;
312-
cudaError_t err = cudaPointerGetAttributes(&attr, a4.dataPtr());
313-
if (err == cudaSuccess && attr.memoryType == cudaMemoryTypeDevice) {
314-
device_type = kDLCUDA;
315-
device_id = attr.device;
316-
}
317-
#elif defined(AMREX_USE_HIP)
318-
// Check if data is on GPU by checking if pointer is in HIP memory
319-
hipPointerAttribute_t attr;
320-
hipError_t err = hipPointerGetAttributes(&attr, a4.dataPtr());
321-
if (err == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
322-
device_type = kDLROCM;
323-
device_id = attr.device;
324-
}
325-
#elif defined(AMREX_USE_DPCPP)
326-
// For SYCL, we need to check if the data is on device
327-
// This is more complex as SYCL doesn't have a simple pointer check
328-
// For now, assume CPU - SYCL support would need more sophisticated detection
329-
// device_type = kDLExtDev; // SYCL would use extended device type
330-
// device_id = 0;
331-
#endif
332-
333-
return std::make_tuple(device_type, device_id);
267+
DLDevice device = dlpack::detect_device_from_pointer(a4.dataPtr());
268+
return std::make_tuple(device.device_type, device.device_id);
334269
}, R"doc(
335270
DLPack device info (device_type, device_id).
336271
)doc")

src/Base/dlpack.h

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)