Skip to content

Commit cbac976

Browse files
committed
Vibe Start
1 parent 2a2587b commit cbac976

File tree

2 files changed

+174
-3
lines changed

2 files changed

+174
-3
lines changed

src/Base/Array4.H

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
#include <sstream>
1919
#include <type_traits>
2020
#include <vector>
21+
#include "dlpack.h"
22+
23+
// GPU backend headers for device detection
24+
#ifdef AMREX_USE_CUDA
25+
#include <cuda_runtime.h>
26+
#endif
27+
#ifdef AMREX_USE_HIP
28+
#include <hip/hip_runtime.h>
29+
#endif
2130

2231

2332
namespace
@@ -222,13 +231,109 @@ namespace pyAMReX
222231
})
223232

224233

225-
// TODO: __dlpack__ __dlpack_device__
226234
// DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
227235
// https://dmlc.github.io/dlpack/latest/
228-
// https://data-apis.org/array-api/latest/design_topics/data_interchange.html
229-
// https://github.com/data-apis/consortium-feedback/issues/1
230236
// https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
231237
// https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
238+
// DLPack producer: export this Array4 as a DLManagedTensor wrapped in a
// Python capsule named "dltensor", per the DLPack data-interchange
// protocol. Zero-copy: the consumer reads the Array4's memory directly,
// so the underlying FAB must outlive the consumer's tensor.
// The stream argument is accepted for protocol compliance but ignored
// (no synchronization is performed here).
.def("__dlpack__", [](Array4<T> const &a4, py::handle /* stream */) {
    constexpr int ndim = 4;

    // Map T to a DLPack dtype FIRST, so an unsupported T throws before
    // the shape/strides allocations below (the original leaked both).
    DLDataType dtype{};
    if constexpr (std::is_same_v<T, float>) { dtype.code = kDLFloat; dtype.bits = 32; dtype.lanes = 1; }
    else if constexpr (std::is_same_v<T, double>) { dtype.code = kDLFloat; dtype.bits = 64; dtype.lanes = 1; }
    else if constexpr (std::is_same_v<T, int32_t>) { dtype.code = kDLInt; dtype.bits = 32; dtype.lanes = 1; }
    else if constexpr (std::is_same_v<T, int64_t>) { dtype.code = kDLInt; dtype.bits = 64; dtype.lanes = 1; }
    else if constexpr (std::is_same_v<T, uint32_t>) { dtype.code = kDLUInt; dtype.bits = 32; dtype.lanes = 1; }
    else if constexpr (std::is_same_v<T, uint64_t>) { dtype.code = kDLUInt; dtype.bits = 64; dtype.lanes = 1; }
    else { throw std::runtime_error("Unsupported dtype for DLPack"); }

    // Device detection based on the compiled-in AMReX GPU backend.
    DLDevice device{ kDLCPU, 0 };
#ifdef AMREX_USE_CUDA
    cudaPointerAttributes attr{};
    cudaError_t err = cudaPointerGetAttributes(&attr, a4.dataPtr());
    // CUDA >= 11: the member is .type (memoryType was removed).
    // Managed (unified) memory is device-accessible too, so report it
    // as a CUDA device pointer as well.
    if (err == cudaSuccess &&
        (attr.type == cudaMemoryTypeDevice || attr.type == cudaMemoryTypeManaged)) {
        device.device_type = kDLCUDA;
        device.device_id = attr.device;
    } else if (err != cudaSuccess) {
        (void) cudaGetLastError();  // clear the sticky error set for plain host pointers
    }
#elif defined(AMREX_USE_HIP)
    hipPointerAttribute_t attr{};
    hipError_t err = hipPointerGetAttributes(&attr, a4.dataPtr());
    // NOTE(review): ROCm >= 5.5 renames memoryType -> type; gate on the
    // HIP version if builds against newer ROCm fail here.
    if (err == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
        device.device_type = kDLROCM;
        device.device_id = attr.device;
    } else if (err != hipSuccess) {
        (void) hipGetLastError();  // clear the sticky error for host pointers
    }
#elif defined(AMREX_USE_DPCPP)
    // SYCL has no simple pointer-attribute query here; report CPU for
    // now. Proper support needs sycl::get_pointer_type with a context.
#endif

    // Logical shape (comp, z, y, x) with the matching element strides
    // of Array4 (x is the unit-stride direction). DLPack strides are in
    // elements, not bytes.
    auto const len = length(a4);
    auto *shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x};
    auto *strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1};

    auto *managed = new DLManagedTensor;
    managed->dl_tensor.data = const_cast<void*>(static_cast<const void*>(a4.dataPtr()));
    managed->dl_tensor.device = device;
    managed->dl_tensor.ndim = ndim;
    managed->dl_tensor.dtype = dtype;
    managed->dl_tensor.shape = shape;
    managed->dl_tensor.strides = strides;
    managed->dl_tensor.byte_offset = 0;
    managed->manager_ctx = nullptr;
    managed->deleter = [](DLManagedTensor *self) {
        delete[] self->dl_tensor.shape;
        delete[] self->dl_tensor.strides;
        delete self;
    };

    // Per the DLPack protocol, a consumer that takes ownership renames
    // the capsule to "used_dltensor" and becomes responsible for calling
    // the deleter itself. Only delete here while the capsule is still an
    // unconsumed "dltensor" — deleting unconditionally double-frees.
    return py::capsule(managed, "dltensor", [](PyObject *cap) {
        if (PyCapsule_IsValid(cap, "dltensor")) {
            auto *mt = static_cast<DLManagedTensor*>(
                PyCapsule_GetPointer(cap, "dltensor"));
            if (mt && mt->deleter) { mt->deleter(mt); }
        }
    });
}, py::arg("stream") = py::none(), R"doc(
DLPack protocol for zero-copy tensor exchange.
See https://dmlc.github.io/dlpack/latest/ for details.
)doc")
304+
// DLPack device query: returns (device_type, device_id) per the DLPack
// protocol. Must agree with the device reported inside __dlpack__ so
// consumers pick the correct import path.
.def("__dlpack_device__", [](Array4<T> const &a4) {
    int device_type = kDLCPU;
    int device_id = 0;

#ifdef AMREX_USE_CUDA
    cudaPointerAttributes attr{};
    cudaError_t err = cudaPointerGetAttributes(&attr, a4.dataPtr());
    // CUDA >= 11: the member is .type (memoryType was removed). Treat
    // managed (unified) memory as device-resident, matching __dlpack__.
    if (err == cudaSuccess &&
        (attr.type == cudaMemoryTypeDevice || attr.type == cudaMemoryTypeManaged)) {
        device_type = kDLCUDA;
        device_id = attr.device;
    } else if (err != cudaSuccess) {
        (void) cudaGetLastError();  // clear the sticky error set for plain host pointers
    }
#elif defined(AMREX_USE_HIP)
    hipPointerAttribute_t attr{};
    hipError_t err = hipPointerGetAttributes(&attr, a4.dataPtr());
    // NOTE(review): ROCm >= 5.5 renames memoryType -> type; gate on the
    // HIP version if builds against newer ROCm fail here.
    if (err == hipSuccess && attr.memoryType == hipMemoryTypeDevice) {
        device_type = kDLROCM;
        device_id = attr.device;
    } else if (err != hipSuccess) {
        (void) hipGetLastError();  // clear the sticky error for host pointers
    }
#elif defined(AMREX_USE_DPCPP)
    // SYCL has no simple pointer-attribute query here; report CPU for
    // now. Proper support needs sycl::get_pointer_type with a context.
#endif

    return std::make_tuple(device_type, device_id);
}, R"doc(
DLPack device info (device_type, device_id).
)doc")
232337

233338
.def("to_host", [](Array4<T> const & a4) {
234339
// py::tuple to std::vector

src/Base/dlpack.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef AMREX_DLPACK_H_
#define AMREX_DLPACK_H_

// Minimal DLPack ABI definitions, mirroring the official header:
//   https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
// The struct layouts below MUST match the official ones bit-for-bit,
// because consumers (NumPy, CuPy, PyTorch, ...) reinterpret the pointer
// handed over in the "dltensor" capsule using their own copy of dlpack.h.

// Standard headers are included OUTSIDE extern "C": they are C++-aware
// and must not be forced into C linkage.
#include <stdint.h>
#include <stddef.h>

#ifdef __cplusplus
extern "C" {
#endif

// Device type codes (DLDeviceType in the official header)
#define kDLCPU 1
#define kDLCUDA 2
#define kDLCUDAHost 3
#define kDLOpenCL 4
#define kDLVulkan 7
#define kDLMetal 8
#define kDLVPI 9
#define kDLROCM 10
#define kDLROCMHost 11
#define kDLExtDev 12

// Data type codes (DLDataTypeCode in the official header)
#define kDLInt 0
#define kDLUInt 1
#define kDLFloat 2

// Device context: which device type / ordinal the tensor lives on.
typedef struct {
    int32_t device_type;  // one of the kDL* device codes above
    int32_t device_id;    // device ordinal (e.g. CUDA device index)
} DLDevice;

// Data type descriptor.
typedef struct {
    uint8_t code;    // kDLFloat=2, kDLInt=0, kDLUInt=1
    uint8_t bits;    // number of bits, e.g., 32, 64
    uint16_t lanes;  // number of lanes (for vector types; 1 for scalars)
} DLDataType;

// Tensor structure. FIELD ORDER MATTERS: the official DLPack layout is
// data, device, ndim, dtype, shape, strides, byte_offset. Placing dtype
// anywhere else shifts every following field and corrupts the exchange.
typedef struct {
    void* data;
    DLDevice device;
    int32_t ndim;
    DLDataType dtype;
    int64_t* shape;
    int64_t* strides;      // in elements, not bytes; can be NULL for compact row-major
    uint64_t byte_offset;  // offset in bytes from data to the first element
} DLTensor;

// Managed tensor with deleter: the producer sets deleter; the consumer
// calls it exactly once when done with the tensor.
struct DLManagedTensor;
typedef void (*DLManagedTensorDeleter)(struct DLManagedTensor* self);

typedef struct DLManagedTensor {
    DLTensor dl_tensor;
    void* manager_ctx;             // opaque producer context (may be NULL)
    DLManagedTensorDeleter deleter;
} DLManagedTensor;

#ifdef __cplusplus
} // extern "C"
#endif

#endif // AMREX_DLPACK_H_

0 commit comments

Comments
 (0)