Skip to content

Commit 2dae0ca

Browse files
committed
Fixing distributed array tests
1 parent 4a94917 commit 2dae0ca

File tree

5 files changed

+160
-10
lines changed

5 files changed

+160
-10
lines changed

examples/darray.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
print(z3.local_slice(), z3.substart, z3.commsizes)
3030

31-
#v0 = newDarray(fft, forward_output=False, rank=1)
32-
v0 = Function(fft, forward_output=False, rank=1)
31+
v0 = newDarray(fft, forward_output=False, rank=1)
32+
#v0 = Function(fft, forward_output=False, rank=1)
3333
v0[:] = np.random.random(v0.shape)
3434
v0c = v0.copy()
3535
v1 = newDarray(fft, forward_output=True, rank=1)
@@ -68,5 +68,16 @@
6868

6969
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(z)**2)
7070
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
71-
print(s0, s1)
72-
assert abs(s0-s1) < 1e-12
71+
if MPI.COMM_WORLD.Get_rank() == 0:
72+
assert abs(s0-s1) < 1e-12
73+
74+
N = (3, 3, 6, 6, 6)
75+
z2 = DistributedArray(N, dtype=float, val=1, alignment=2, rank=2)
76+
z2[:] = MPI.COMM_WORLD.Get_rank()
77+
z1 = z2.redistribute(1)
78+
z0 = z1.redistribute(0)
79+
80+
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
81+
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(z0)**2)
82+
if MPI.COMM_WORLD.Get_rank() == 0:
83+
assert abs(s0-s1) < 1e-12

mpi4py_fft/distributedarray.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ class DistributedArray(np.ndarray):
5353
"""
5454
def __new__(cls, global_shape, subcomm=None, val=None, dtype=np.float,
5555
buffer=None, alignment=None, rank=0):
56-
5756
if rank > 0:
5857
assert global_shape[:rank] == (len(global_shape[rank:]),)*rank
5958

@@ -67,13 +66,16 @@ def __new__(cls, global_shape, subcomm=None, val=None, dtype=np.float,
6766
subcomm = Subcomm(comm, subcomm)
6867
else:
6968
assert subcomm is None
69+
subcomm = [0] * len(global_shape[rank:])
7070
if alignment is not None:
71-
subcomm = [0] * len(global_shape[rank:])
7271
subcomm[alignment] = 1
72+
else:
73+
subcomm[-1] = 1
74+
alignment = len(subcomm)-1
7375
subcomm = Subcomm(comm, subcomm)
7476
sizes = [s.Get_size() for s in subcomm]
7577
if alignment is not None:
76-
assert isinstance(alignment, int)
78+
assert isinstance(alignment, (int, np.integer))
7779
assert sizes[alignment] == 1
7880
else:
7981
# Decide that alignment is the last axis with size 1
@@ -131,12 +133,16 @@ def rank(self):
131133
return self._rank
132134

133135
def __getitem__(self, i):
134-
# Return DistributedArray if i is an integer and rank > 0
136+
# Return DistributedArray if the result is a component of a tensor
135137
# Otherwise return ndarray view
136138
if isinstance(i, int) and self.rank > 0:
137139
v0 = np.ndarray.__getitem__(self, i)
138140
v0._rank -= 1
139141
return v0
142+
if isinstance(i, tuple) and self.rank == 2:
143+
v0 = np.ndarray.__getitem__(self, i)
144+
v0._rank = 0
145+
return v0
140146
return np.ndarray.__getitem__(self.v, i)
141147

142148
@property
@@ -288,7 +294,16 @@ def redistribute(self, axis=None, darray=None):
288294
dtype=self.dtype,
289295
alignment=axis,
290296
rank=self.rank)
291-
transfer.forward(self, darray)
297+
if self.rank == 0:
298+
transfer.forward(self, darray)
299+
elif self.rank == 1:
300+
for i in range(self.shape[0]):
301+
transfer.forward(self[i], darray[i])
302+
elif self.rank == 2:
303+
for i in range(self.shape[0]):
304+
for j in range(self.shape[1]):
305+
transfer.forward(self[i, j], darray[i, j])
306+
292307
return darray
293308

294309
def newDarray(pfft, forward_output=True, val=0, rank=0, view=False):

mpi4py_fft/mpifft.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,12 @@ def __init__(self, comm, shape=None, axes=None, dtype=float, slab=False,
200200
axes = list(axes) if np.ndim(axes) else [axes]
201201
else:
202202
axes = list(range(len(shape)))
203+
if darray is not None:
204+
# Make sure aligned axis of darray is transformed first
205+
axes = list(np.roll(axes, len(shape)-1-darray.alignment))
203206

204207
for i, ax in enumerate(axes):
205-
if isinstance(ax, int):
208+
if isinstance(ax, (int, np.integer)):
206209
if ax < 0:
207210
ax += len(shape)
208211
axes[i] = (ax,)

tests/runtests.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ if [ $PY -eq 3 ]; then
1414
python -m coverage run -m test_fftw
1515
python -m coverage run -m test_libfft
1616
python -m coverage run -m test_io
17+
python -m coverage run -m test_darray
1718
mpiexec -n 2 python -m coverage run -m test_pencil
1819

1920
#mpiexec -n 4 python -m coverage test_pencil.py
@@ -25,6 +26,8 @@ if [ $PY -eq 3 ]; then
2526
mpiexec -n 2 python -m coverage run spectral_dns_solver.py
2627
mpiexec -n 2 python -m coverage run -m test_io
2728
mpiexec -n 4 python -m coverage run -m test_io
29+
mpiexec -n 2 python -m coverage run -m test_darray
30+
mpiexec -n 4 python -m coverage run -m test_darray
2831

2932
python -m coverage combine
3033

@@ -38,5 +41,7 @@ else
3841
#mpiexec -n 4 python test_mpifft.py
3942
# mpiexec -n 8 python test_mpifft.py
4043
# mpiexec -n 12 python test_mpifft.py
44+
mpiexec -n 2 python test_io.py
45+
mpiexec -n 2 python test_darray.py
4146
mpiexec -n 2 python spectral_dns_solver.py
4247
fi

tests/test_darray.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import numpy as np
2+
from mpi4py import MPI
3+
from mpi4py_fft import DistributedArray, newDarray, PFFT
4+
from mpi4py_fft.pencil import Subcomm
5+
6+
comm = MPI.COMM_WORLD
7+
8+
def test_2Darray():
9+
N = (8, 8)
10+
for subcomm in ((0, 1), (1, 0), None, Subcomm(comm, (0, 1))):
11+
for rank in (0, 1, 2):
12+
M = (2,)*rank + N
13+
alignment = None
14+
if subcomm is None and rank == 1:
15+
alignment = 1
16+
a = DistributedArray(M, subcomm=subcomm, val=1, rank=rank, alignment=alignment)
17+
assert a.rank == rank
18+
assert a.global_shape == M
19+
s = a.substart
20+
c = a.subcomm
21+
z = a.commsizes
22+
p = a.pencil
23+
assert np.prod(np.array(z)) == comm.Get_size()
24+
if rank > 0:
25+
a0 = a[0]
26+
assert isinstance(a0, DistributedArray)
27+
assert a0.rank == rank-1
28+
aa = a.v
29+
assert isinstance(aa, np.ndarray)
30+
k = a.get_global_slice((0,)*rank+(0, slice(None)))
31+
if comm.Get_rank() == 0:
32+
assert len(k) == N[1]
33+
assert np.sum(k) == N[1]
34+
k = a.get_global_slice((0,)*rank+(slice(None), 0))
35+
if comm.Get_rank() == 0:
36+
assert len(k) == N[0]
37+
assert np.sum(k) == N[0]
38+
ls = a.local_slice()
39+
newaxis = (a.alignment+1)%2
40+
p0, t = a.get_pencil_and_transfer(newaxis)
41+
a[:] = MPI.COMM_WORLD.Get_rank()
42+
b = a.redistribute(newaxis)
43+
a = b.redistribute(darray=a)
44+
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(a)**2)
45+
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(b)**2)
46+
if MPI.COMM_WORLD.Get_rank() == 0:
47+
assert abs(s0-s1) < 1e-1
48+
49+
def test_3Darray():
50+
N = (8, 8, 8)
51+
for subcomm in ((0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0), None, Subcomm(comm, (0, 0, 1))):
52+
for rank in (0, 1, 2):
53+
M = (3,)*rank + N
54+
alignment = None
55+
if subcomm is None and rank == 1:
56+
alignment = 2
57+
a = DistributedArray(M, subcomm=subcomm, val=1, rank=rank, alignment=alignment)
58+
assert a.rank == rank
59+
assert a.global_shape == M
60+
s = a.substart
61+
c = a.subcomm
62+
z = a.commsizes
63+
p = a.pencil
64+
assert np.prod(np.array(z)) == comm.Get_size()
65+
if rank > 0:
66+
a0 = a[0]
67+
assert isinstance(a0, DistributedArray)
68+
assert a0.rank == rank-1
69+
if rank == 2:
70+
a0 = a[0, 1]
71+
assert isinstance(a0, DistributedArray)
72+
assert a0.rank == 0
73+
aa = a.v
74+
assert isinstance(aa, np.ndarray)
75+
k = a.get_global_slice((0,)*rank+(0, 0, slice(None)))
76+
if comm.Get_rank() == 0:
77+
assert len(k) == N[2]
78+
assert np.sum(k) == N[2]
79+
k = a.get_global_slice((0,)*rank+(slice(None), 0, 0))
80+
if comm.Get_rank() == 0:
81+
assert len(k) == N[0]
82+
assert np.sum(k) == N[0]
83+
ls = a.local_slice()
84+
newaxis = (a.alignment+1)%3
85+
p0, t = a.get_pencil_and_transfer(newaxis)
86+
a[:] = MPI.COMM_WORLD.Get_rank()
87+
b = a.redistribute(newaxis)
88+
a = b.redistribute(darray=a)
89+
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(a)**2)
90+
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(b)**2)
91+
if MPI.COMM_WORLD.Get_rank() == 0:
92+
assert abs(s0-s1) < 1e-1
93+
94+
def test_newDarray():
95+
N = (8, 8, 8)
96+
pfft = PFFT(MPI.COMM_WORLD, N)
97+
for forward_output in (True, False):
98+
for view in (True, False):
99+
for rank in (0, 1, 2):
100+
a = newDarray(pfft, forward_output=forward_output,
101+
rank=rank, view=view)
102+
if view is False:
103+
assert isinstance(a, DistributedArray)
104+
assert a.rank == rank
105+
if rank == 0:
106+
qfft = PFFT(MPI.COMM_WORLD, darray=a)
107+
elif rank == 1:
108+
qfft = PFFT(MPI.COMM_WORLD, darray=a[0])
109+
else:
110+
assert isinstance(a, np.ndarray)
111+
assert a.base.rank == rank
112+
113+
if __name__ == '__main__':
114+
test_2Darray()
115+
test_3Darray()
116+
test_newDarray()

0 commit comments

Comments
 (0)