18 | 18 | #include <sstream> |
19 | 19 | #include <type_traits> |
20 | 20 | #include <vector> |
| 21 | +#include <cstdint> // fixed-width integer types in the DLPack dtype mapping
| | +#include <stdexcept> // std::runtime_error for unsupported value types
| | +#include <tuple> // std::make_tuple in __dlpack_device__
| | +
| | +#include "dlpack.h"
| 22 | + |
| 23 | +// GPU backend headers for device detection |
| 24 | +#ifdef AMREX_USE_CUDA |
| 25 | +#include <cuda_runtime.h> |
| 26 | +#endif |
| 27 | +#ifdef AMREX_USE_HIP |
| 28 | +#include <hip/hip_runtime.h> |
| 29 | +#endif |
21 | 30 |
22 | 31 |
23 | 32 | namespace |
@@ -222,13 +231,109 @@ namespace pyAMReX |
222 | 231 | }) |
223 | 232 |
224 | 233 |
225 | | - // TODO: __dlpack__ __dlpack_device__ |
226 | 234 | // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.) |
227 | 235 | // https://dmlc.github.io/dlpack/latest/ |
228 | | - // https://data-apis.org/array-api/latest/design_topics/data_interchange.html |
229 | | - // https://github.com/data-apis/consortium-feedback/issues/1 |
230 | 236 | // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h |
231 | 237 | // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol |
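| | + // This binds the pre-1.0 DLPack flavor: a DLManagedTensor exposed through
| | + // a capsule named "dltensor" (DLPack 1.0 adds DLManagedTensorVersioned
| | + // under the capsule name "dltensor_versioned").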
| 238 | + .def("__dlpack__", [](Array4<T> const &a4, py::handle stream = py::none()) { |
| 239 | + // Allocate shape/strides arrays |
| 240 | + constexpr int ndim = 4; |
| 241 | + auto const len = length(a4); |
| 242 | + auto *shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x}; |
| 243 | + auto *strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1}; |
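| | + // Array4 is contiguous in x (column-major in x/y/z), so it is exposed
| | + // as a C-ordered (ncomp, z, y, x) tensor; DLPack strides are counted
| | + // in elements, matching Array4's nstride/kstride/jstride.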
| 244 | + // DLPack dtype |
| 245 | + DLDataType dtype{}; |
| 246 | + if constexpr (std::is_same_v<T, float>) { dtype.code = kDLFloat; dtype.bits = 32; dtype.lanes = 1; } |
| 247 | + else if constexpr (std::is_same_v<T, double>) { dtype.code = kDLFloat; dtype.bits = 64; dtype.lanes = 1; } |
| 248 | + else if constexpr (std::is_same_v<T, int32_t>) { dtype.code = kDLInt; dtype.bits = 32; dtype.lanes = 1; } |
| 249 | + else if constexpr (std::is_same_v<T, int64_t>) { dtype.code = kDLInt; dtype.bits = 64; dtype.lanes = 1; } |
| 250 | + else if constexpr (std::is_same_v<T, uint32_t>) { dtype.code = kDLUInt; dtype.bits = 32; dtype.lanes = 1; } |
| 251 | + else if constexpr (std::is_same_v<T, uint64_t>) { dtype.code = kDLUInt; dtype.bits = 64; dtype.lanes = 1; } |
| 252 | + else { delete[] shape; delete[] strides; throw std::runtime_error("Unsupported Array4 value type for DLPack"); }
| 253 | + |
| 254 | + // Device detection based on AMReX GPU backend |
| 255 | + DLDevice device{ kDLCPU, 0 }; |
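| | + // Default to host; overridden below if a GPU backend owns the pointer.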
| 256 | +#ifdef AMREX_USE_CUDA |
| 257 | + // Classify the pointer via the CUDA runtime. The attribute field is
| | + // `type` (CUDA >= 10); the old `memoryType` field was removed in CUDA 11.
| 258 | + cudaPointerAttributes attr;
| 259 | + cudaError_t err = cudaPointerGetAttributes(&attr, a4.dataPtr());
| 260 | + if (err == cudaSuccess && attr.type == cudaMemoryTypeDevice) {
| 261 | + device.device_type = kDLCUDA;
| 262 | + device.device_id = attr.device;
| 263 | + } else if (err == cudaSuccess && attr.type == cudaMemoryTypeManaged) {
| | + device.device_type = kDLCUDAManaged; // unified (managed) memory
| | + device.device_id = attr.device;
| | + }
| 264 | +#elif defined(AMREX_USE_HIP) |
| 265 | + // Classify the pointer via the HIP runtime. The attribute field is
| | + // named `type` as of ROCm 6.0 (`memoryType` in older releases).
| 266 | + hipPointerAttribute_t attr;
| 267 | + hipError_t err = hipPointerGetAttributes(&attr, a4.dataPtr());
| 268 | + if (err == hipSuccess && attr.type == hipMemoryTypeDevice) {
| 269 | + device.device_type = kDLROCM;
| 270 | + device.device_id = attr.device;
| 271 | + }
| 272 | +#elif defined(AMREX_USE_DPCPP) |
| 273 | + // SYCL has no pointer-only query: sycl::get_pointer_type() requires a
| 274 | + // sycl::context, which is not available from the Array4 alone. DLPack
| 275 | + // defines kDLOneAPI for SYCL devices; until proper detection is wired
| 276 | + // up, keep the CPU default from above.
| 278 | +#endif |
| 279 | + |
| 280 | + // Construct DLTensor |
| 281 | + auto *dl_tensor = new DLManagedTensor; |
| 282 | + dl_tensor->dl_tensor.data = const_cast<void*>(static_cast<const void*>(a4.dataPtr())); |
| 283 | + dl_tensor->dl_tensor.device = device; |
| 284 | + dl_tensor->dl_tensor.ndim = ndim; |
| 285 | + dl_tensor->dl_tensor.dtype = dtype; |
| 286 | + dl_tensor->dl_tensor.shape = shape; |
| 287 | + dl_tensor->dl_tensor.strides = strides; |
| 288 | + dl_tensor->dl_tensor.byte_offset = 0; |
| 289 | + dl_tensor->manager_ctx = nullptr; |
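| | + // NOTE: manager_ctx stays null because no ownership is transferred;
| | + // the capsule does not keep the underlying AMReX data alive, so the
| | + // producing object must outlive every DLPack consumer.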
| 290 | + dl_tensor->deleter = [](DLManagedTensor *self) { |
| 291 | + delete[] self->dl_tensor.shape; |
| 292 | + delete[] self->dl_tensor.strides; |
| 293 | + delete self; |
| 294 | + }; |
| 295 | + // Return as a Python capsule. Per the DLPack protocol, a consumer that
| | + // takes ownership renames the capsule to "used_dltensor"; delete here
| | + // only if that never happened, to avoid a double free.
| 296 | + return py::capsule(dl_tensor, "dltensor", [](PyObject *capsule) {
| 297 | + if (PyCapsule_IsValid(capsule, "dltensor")) {
| 298 | + auto *tensor = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule, "dltensor"));
| 299 | + tensor->deleter(tensor);
| | + }
| | + });
| 300 | + }, py::kw_only(), py::arg("stream") = py::none(), R"doc(
| 301 | + DLPack protocol for zero-copy tensor exchange. |
| 302 | + See https://dmlc.github.io/dlpack/latest/ for details. |
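| | +
| | + The stream argument is accepted for protocol compatibility but is
| | + currently ignored (no synchronization is performed).
| | +
| | + A hypothetical consumer-side sketch, assuming ``arr4`` is an Array4
| | + and PyTorch >= 1.10 as the DLPack consumer:
| | +
| | + >>> import torch
| | + >>> t = torch.from_dlpack(arr4) # zero-copy view of the Array4 data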
| 303 | + )doc") |
| 304 | + .def("__dlpack_device__", [](Array4<T> const &a4) { |
| 305 | + // Device detection based on AMReX GPU backend |
| 306 | + int device_type = kDLCPU; |
| 307 | + int device_id = 0; |
| 308 | + |
| 309 | +#ifdef AMREX_USE_CUDA |
| 310 | + // Same pointer classification as in __dlpack__ (`type` needs CUDA >= 10).
| 311 | + cudaPointerAttributes attr;
| 312 | + cudaError_t err = cudaPointerGetAttributes(&attr, a4.dataPtr());
| 313 | + if (err == cudaSuccess && attr.type == cudaMemoryTypeDevice) {
| 314 | + device_type = kDLCUDA;
| 315 | + device_id = attr.device;
| 316 | + } else if (err == cudaSuccess && attr.type == cudaMemoryTypeManaged) {
| | + device_type = kDLCUDAManaged;
| | + device_id = attr.device;
| | + }
| 317 | +#elif defined(AMREX_USE_HIP) |
| 318 | + // Same pointer classification as in __dlpack__ (`type` as of ROCm 6.0).
| 319 | + hipPointerAttribute_t attr;
| 320 | + hipError_t err = hipPointerGetAttributes(&attr, a4.dataPtr());
| 321 | + if (err == hipSuccess && attr.type == hipMemoryTypeDevice) {
| 322 | + device_type = kDLROCM;
| 323 | + device_id = attr.device;
| 324 | + }
| 325 | +#elif defined(AMREX_USE_DPCPP) |
| 326 | + // See the note in __dlpack__: SYCL pointer classification needs a
| 327 | + // sycl::context, so keep the CPU default for now (DLPack defines
| 328 | + // kDLOneAPI for SYCL devices).
| 331 | +#endif |
| 332 | + |
| 333 | + return std::make_tuple(device_type, device_id); |
| 334 | + }, R"doc( |
| 335 | + DLPack device info (device_type, device_id). |
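| | + Common values per DLDeviceType in dlpack.h: (1, 0) = CPU,
| | + (2, id) = CUDA, (13, id) = CUDA managed, (10, id) = ROCm.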
| 336 | + )doc") |
232 | 337 |
|
233 | 338 | .def("to_host", [](Array4<T> const & a4) { |
234 | 339 | // py::tuple to std::vector |