Skip to content

Commit ea717f7

Browse files
authored
Merge pull request #110 from fverdugo/fix_gather
Fix gather
2 parents a203a50 + e970b50 commit ea717f7

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

src/mpi_array.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,10 @@ end
296296
function gather_impl!(
297297
rcv::MPIArray, snd::MPIArray,
298298
destination, ::Type{T}) where T <: AbstractVector
299+
Tv = eltype(snd.item)
299300
@assert rcv.comm === snd.comm
300301
@assert isa(rcv.item,JaggedArray)
302+
@assert eltype(eltype(rcv.item)) == Tv
301303
comm = snd.comm
302304
if isa(destination,Integer)
303305
root = destination-1
@@ -308,14 +310,14 @@ function gather_impl!(
308310
rcv_buffer = MPI.VBuffer(rcv.item.data,counts)
309311
MPI.Gatherv!(MPI.IN_PLACE,rcv_buffer,root,comm)
310312
else
311-
MPI.Gatherv!(snd.item,nothing,root,comm)
313+
MPI.Gatherv!(convert(Vector{Tv},snd.item),nothing,root,comm)
312314
end
313315
else
314316
@assert destination === :all
315317
@assert length(rcv.item) == MPI.Comm_size(comm)
316318
counts = ptrs_to_counts(rcv.item.ptrs)
317319
rcv_buffer = MPI.VBuffer(rcv.item.data,counts)
318-
MPI.Allgatherv!(snd.item,rcv_buffer,comm)
320+
MPI.Allgatherv!(convert(Vector{Tv},snd.item),rcv_buffer,comm)
319321
end
320322
rcv
321323
end
@@ -541,18 +543,18 @@ function exchange_impl!(
541543
end
542544

543545
# This should go eventually into MPI.jl!
544-
Issend(data, comm::MPI.Comm, req::MPI.AbstractRequest=MPI.Request(); dest::Integer, tag::Integer=0) =
546+
Issend(data, comm::MPI.Comm, req=MPI.Request(); dest::Integer, tag::Integer=0) =
545547
Issend(data, dest, tag, comm, req)
546548

547-
function Issend(buf::MPI.Buffer, dest::Integer, tag::Integer, comm::MPI.Comm, req::MPI.AbstractRequest=MPI.Request())
549+
function Issend(buf::MPI.Buffer, dest::Integer, tag::Integer, comm::MPI.Comm, req=MPI.Request())
548550
@assert MPI.isnull(req)
549551
# int MPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dest,
550552
# int tag, MPI_Comm comm, MPI_Request *request)
551553
MPI.API.MPI_Issend(buf.data, buf.count, buf.datatype, dest, tag, comm, req)
552554
MPI.setbuffer!(req, buf)
553555
return req
554556
end
555-
Issend(data, dest::Integer, tag::Integer, comm::MPI.Comm, req::MPI.AbstractRequest=MPI.Request()) =
557+
Issend(data, dest::Integer, tag::Integer, comm::MPI.Comm, req=MPI.Request()) =
556558
Issend(MPI.Buffer_send(data), dest, tag, comm, req)
557559

558560

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
11
module MPIArrayTests
22

33
using PartitionedArrays
4+
using Test
45

56
with_mpi() do distribute
67
rank = distribute(LinearIndices((4,)))
78
display(rank)
89
rank = distribute(LinearIndices((2,2)))
910
display(rank)
11+
12+
n = 4
13+
row_partition = uniform_partition(rank,n)
14+
my_own_to_global = map(own_to_global,row_partition)
15+
ids = gather(my_own_to_global)
16+
map_main(ids) do myids
17+
@test myids == [[1],[2],[3],[4]]
18+
end
19+
ids = gather(my_own_to_global,destination=:all)
20+
map(ids) do myids
21+
@test myids == [[1],[2],[3],[4]]
22+
end
23+
1024
end
1125

1226
end # module

0 commit comments

Comments
 (0)