From 566bd197bcd97f8522c84a400877f7697de7c0ac Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Tue, 9 Sep 2025 15:50:31 +0000 Subject: [PATCH] oneAPI-aware MPI --- Project.toml | 4 ++++ ext/OneAPIExt.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 ext/OneAPIExt.jl diff --git a/Project.toml b/Project.toml index 7e3d66252..ac4123773 100644 --- a/Project.toml +++ b/Project.toml @@ -34,15 +34,19 @@ Requires = "~0.5, 1.0" Serialization = "1" Sockets = "1" julia = "1.6" +oneAPI = "2.1" [extensions] AMDGPUExt = "AMDGPU" CUDAExt = "CUDA" +OneAPIExt = "oneAPI" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" diff --git a/ext/OneAPIExt.jl b/ext/OneAPIExt.jl new file mode 100644 index 000000000..6891aaf43 --- /dev/null +++ b/ext/OneAPIExt.jl @@ -0,0 +1,27 @@ +module OneAPIExt + +import MPI +isdefined(Base, :get_extension) ? (import oneAPI) : (import ..oneAPI) +import MPI: MPIPtr, Buffer, Datatype + +function Base.cconvert(::Type{MPIPtr}, A::oneAPI.oneArray{T}) where T + A +end + +function Base.unsafe_convert(::Type{MPIPtr}, X::oneAPI.oneArray{T}) where T + reinterpret(MPIPtr, Base.unsafe_convert(oneAPI.ZePtr{T}, X)) +end + +# only need to define this for strided arrays: all others can be handled by generic machinery +function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:oneAPI.oneArray,I} + X = parent(V) + pX = Base.unsafe_convert(oneAPI.ZePtr{T}, X) + pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) + return reinterpret(MPIPtr, pV) +end + +function Buffer(arr::oneAPI.oneArray) + Buffer(arr, Cint(length(arr)), Datatype(eltype(arr))) +end + +end # OneAPIExt