Skip to content

Commit 7b094f1

Browse files
committed
Adding DistributedArray class
1 parent 8c04f8f commit 7b094f1

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

examples/darray.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
from mpi4py import MPI
3+
from mpi4py_fft.pencil import Subcomm
4+
from mpi4py_fft.distributedarray import DistributedArray
5+
from mpi4py_fft.mpifft import PFFT, Function
6+
7+
# Test DistributedArray. Start with alignment in axis 0, then tranfer to 1 and
8+
# finally to 2
9+
N = (16, 14, 12)
10+
z0 = DistributedArray(N, dtype=np.float, alignment=0)
11+
z0[:] = np.random.randint(0, 10, z0.shape)
12+
s0 = MPI.COMM_WORLD.allreduce(np.sum(z0))
13+
print(MPI.COMM_WORLD.Get_rank(), z0.shape)
14+
z1 = z0.redistribute(1)
15+
s1 = MPI.COMM_WORLD.allreduce(np.sum(z1))
16+
print(MPI.COMM_WORLD.Get_rank(), z1.shape)
17+
z2 = z1.redistribute(2)
18+
print(MPI.COMM_WORLD.Get_rank(), z2.shape)
19+
20+
s2 = MPI.COMM_WORLD.allreduce(np.sum(z2))
21+
assert s0 == s1 == s2
22+
23+
fft = PFFT(Subcomm(MPI.COMM_WORLD, [s.Get_size() for s in z2.p0.subcomm]), N, dtype=z2.dtype)
24+
z3 = Function(fft, True)
25+
fft.forward(z2, z3)

mpi4py_fft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
__version__ = '1.1.2'
2020
__author__ = 'Lisandro Dalcin and Mikael Mortensen'
2121

22+
from .distributedarray import DistributedArray
2223
from .mpifft import PFFT, Function
2324
from . import fftw
2425
from .utilities import HDF5File, NCFile, generate_xdmf

mpi4py_fft/distributedarray.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
from mpi4py import MPI
3+
from .pencil import Pencil, Subcomm
4+
5+
comm = MPI.COMM_WORLD
6+
7+
class DistributedArray(np.ndarray):
8+
"""Distributed Numpy array
9+
10+
Parameters
11+
----------
12+
global_shape : sequence of ints
13+
Shape of non-distributed global array
14+
subcomm : None, Subcomm instance or sequence of ints
15+
Describes how to distribute the array
16+
val : int
17+
Initialize array with this value if buffer is not given
18+
dtype : np.dtype
19+
Type of array
20+
buffer : np.ndarray
21+
Array of correct shape
22+
alignment : None or int
23+
Make sure array is aligned in this direction
24+
25+
"""
26+
def __new__(cls, global_shape, subcomm=None, val=0, dtype=np.float,
27+
buffer=None, alignment=None):
28+
if isinstance(subcomm, Subcomm):
29+
pass
30+
else:
31+
if isinstance(subcomm, (tuple, list)):
32+
assert len(subcomm) == len(global_shape)
33+
# Do nothing if already containing communicators. A tuple of subcommunicators is not necessarily a Subcomm
34+
if not np.all([isinstance(s, MPI.Cartcomm) for s in subcomm]):
35+
subcomm = Subcomm(comm, subcomm)
36+
else:
37+
assert subcomm is None
38+
if alignment is not None:
39+
subcomm = [0] * len(global_shape)
40+
subcomm[alignment] = 1
41+
subcomm = Subcomm(comm, subcomm)
42+
43+
sizes = [s.Get_size() for s in subcomm]
44+
if alignment is not None:
45+
assert isinstance(alignment, int)
46+
assert sizes[alignment] == 1
47+
else:
48+
# Decide that alignment is the last axis with size 1
49+
alignment = np.flatnonzero(np.array(sizes) == 1)[-1]
50+
p0 = Pencil(subcomm, global_shape, axis=alignment)
51+
obj = np.ndarray.__new__(cls, p0.subshape, dtype=dtype, buffer=buffer)
52+
if buffer is None:
53+
obj.fill(val)
54+
obj.p0 = p0
55+
obj.global_shape = global_shape
56+
return obj
57+
58+
def alignment(self):
59+
return self.p0.axis
60+
61+
def __array_finalize__(self, obj):
62+
if obj is None:
63+
return
64+
self.p0 = getattr(obj, 'p0', None)
65+
self.global_shape = getattr(obj, 'global_shape', None)
66+
67+
def redistribute(self, axis):
68+
"""Global redistribution of array into alignment in ``axis``
69+
70+
Parameters
71+
----------
72+
axis : int
73+
Align array along this axis
74+
75+
Returns
76+
-------
77+
DistributedArray
78+
self array globally redistributed along new axis
79+
80+
"""
81+
p1 = self.p0.pencil(axis)
82+
transfer = self.p0.transfer(p1, self.dtype)
83+
z0 = np.zeros(p1.subshape, dtype=self.dtype)
84+
transfer.forward(self, z0)
85+
return DistributedArray(self.global_shape,
86+
subcomm=p1.subcomm,
87+
dtype=self.dtype,
88+
alignment=axis,
89+
buffer=z0)

0 commit comments

Comments
 (0)