Skip to content

Commit b507090

Browse files
committed
Adding support for non isbitstypes to gather/scatter
1 parent f03ff94 commit b507090

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

src/mpi_array.jl

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,21 @@ function gather_impl!(
281281
comm = snd.comm
282282
if isa(destination,Integer)
283283
root = destination-1
284-
if MPI.Comm_rank(comm) == root
285-
@assert length(rcv.item) == MPI.Comm_size(comm)
286-
rcv.item[destination] = snd.item
287-
rcv_buffer = MPI.UBuffer(rcv.item,1)
288-
MPI.Gather!(MPI.IN_PLACE,rcv_buffer,root,comm)
284+
if isbitstype(T)
285+
if MPI.Comm_rank(comm) == root
286+
@assert length(rcv.item) == MPI.Comm_size(comm)
287+
rcv.item[destination] = snd.item
288+
rcv_buffer = MPI.UBuffer(rcv.item,1)
289+
MPI.Gather!(MPI.IN_PLACE,rcv_buffer,root,comm)
290+
else
291+
MPI.Gather!(snd.item_ref,nothing,root,comm)
292+
end
289293
else
290-
MPI.Gather!(snd.item_ref,nothing,root,comm)
294+
if MPI.Comm_rank(comm) == root
295+
rcv.item[:] = MPI.gather(snd.item,comm;root)
296+
else
297+
MPI.gather(snd.item,comm;root)
298+
end
291299
end
292300
else
293301
@assert destination === :all
@@ -330,17 +338,25 @@ end
330338
function scatter_impl!(
331339
rcv::MPIArray,snd::MPIArray,
332340
source,::Type{T}) where T
333-
@assert source !== :all "All to all not implemented"
334-
@assert rcv.comm === snd.comm
335-
@assert eltype(snd.item) == typeof(rcv.item)
336341
comm = snd.comm
337342
root = source - 1
338-
if MPI.Comm_rank(comm) == root
339-
snd_buffer = MPI.UBuffer(snd.item,1)
340-
rcv.item = snd.item[source]
341-
MPI.Scatter!(snd_buffer,MPI.IN_PLACE,root,comm)
343+
@assert source !== :all "All to all not implemented"
344+
@assert rcv.comm === snd.comm
345+
if isbitstype(T)
346+
@assert eltype(snd.item) == typeof(rcv.item)
347+
if MPI.Comm_rank(comm) == root
348+
snd_buffer = MPI.UBuffer(snd.item,1)
349+
rcv.item = snd.item[source]
350+
MPI.Scatter!(snd_buffer,MPI.IN_PLACE,root,comm)
351+
else
352+
MPI.Scatter!(nothing,rcv.item_ref,root,comm)
353+
end
342354
else
343-
MPI.Scatter!(nothing,rcv.item_ref,root,comm)
355+
if MPI.Comm_rank(comm) == root
356+
rcv.item_ref[] = MPI.scatter(snd.item,comm;root)
357+
else
358+
rcv.item_ref[] = MPI.scatter(nothing,comm;root)
359+
end
344360
end
345361
rcv
346362
end

test/primitives_tests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11

22
using Test
33

4+
struct NonIsBitsType{T}
5+
data::Vector{T}
6+
end
7+
Base.:(==)(a::NonIsBitsType,b::NonIsBitsType) = a.data == b.data
8+
49
function primitives_tests(distribute)
510

611
rank = distribute(LinearIndices((2,2)))
@@ -61,6 +66,15 @@ function primitives_tests(distribute)
6166
@test rcv == [[1],[1,2],[1,2,3],[1,2,3,4]]
6267
end
6368

69+
snd2 = map(rank) do rank
70+
NonIsBitsType([2])
71+
end
72+
rcv2 = gather(snd2)
73+
snd3 = scatter(rcv2)
74+
map(snd2,snd3) do snd2,snd3
75+
@test snd2 == snd3
76+
end
77+
6478
rcv = multicast(rank,source=2)
6579
map(rcv) do rcv
6680
@test rcv == 2

0 commit comments

Comments
 (0)