Skip to content

Commit 34bcef8

Browse files
committed
For consistency with DistArray local_shape -> shape and shape -> global_shape
1 parent f55e8d8 commit 34bcef8

File tree

8 files changed

+72
-183
lines changed

8 files changed

+72
-183
lines changed

examples/darray.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@
7070
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
7171
if MPI.COMM_WORLD.Get_rank() == 0:
7272
assert abs(s0-s1) < 1e-12
73-
print('hei')
73+
7474
N = (3, 3, 6, 6, 6)
7575
z2 = DistArray(N, dtype=float, val=1, alignment=2, rank=2)
7676
z2[:] = MPI.COMM_WORLD.Get_rank()
77-
#z1 = z2.redistribute(1)
78-
#z0 = z1.redistribute(0)
77+
z1 = z2.redistribute(1)
78+
z0 = z1.redistribute(0)
7979

8080
s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
8181
s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(z0)**2)

examples/spectral_dns_solver.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,32 @@
2828
FFT_pad = FFT
2929

3030
# Declare variables needed to solve Navier-Stokes
31-
U = newDistArray(FFT, False, rank=1) # Velocity
32-
U_hat = newDistArray(FFT, rank=1) # Velocity transformed
33-
P = newDistArray(FFT, False) # Pressure (scalar)
34-
P_hat = newDistArray(FFT) # Pressure transformed
35-
U_hat0 = newDistArray(FFT, rank=1) # Runge-Kutta work array
36-
U_hat1 = newDistArray(FFT, rank=1) # Runge-Kutta work array
31+
U = newDistArray(FFT, False, rank=1, view=True) # Velocity
32+
U_hat = newDistArray(FFT, rank=1, view=True) # Velocity transformed
33+
P = newDistArray(FFT, False, view=True) # Pressure (scalar)
34+
P_hat = newDistArray(FFT, view=True) # Pressure transformed
35+
U_hat0 = newDistArray(FFT, rank=1, view=True) # Runge-Kutta work array
36+
U_hat1 = newDistArray(FFT, rank=1, view=True) # Runge-Kutta work array
3737
a = [1./6., 1./3., 1./3., 1./6.] # Runge-Kutta parameter
3838
b = [0.5, 0.5, 1.] # Runge-Kutta parameter
39-
dU = newDistArray(FFT, rank=1) # Right hand side of ODEs
40-
curl = newDistArray(FFT, False, rank=1)
41-
U_pad = newDistArray(FFT_pad, False, rank=1)
42-
curl_pad = newDistArray(FFT_pad, False, rank=1)
39+
dU = newDistArray(FFT, rank=1, view=True) # Right hand side of ODEs
40+
curl = newDistArray(FFT, False, rank=1, view=True)
41+
U_pad = newDistArray(FFT_pad, False, rank=1, view=True)
42+
curl_pad = newDistArray(FFT_pad, False, rank=1, view=True)
4343

4444
def get_local_mesh(FFT, L):
4545
"""Returns local mesh."""
4646
X = np.ogrid[FFT.local_slice(False)]
47-
N = FFT.shape()
47+
N = FFT.global_shape()
4848
for i in range(len(N)):
4949
X[i] = (X[i]*L[i]/N[i])
50-
X = [np.broadcast_to(x, FFT.local_shape(False)) for x in X]
50+
X = [np.broadcast_to(x, FFT.shape(False)) for x in X]
5151
return X
5252

5353
def get_local_wavenumbermesh(FFT, L):
5454
"""Returns local wavenumber mesh."""
5555
s = FFT.local_slice()
56-
N = FFT.shape()
56+
N = FFT.global_shape()
5757
# Set wavenumbers in grid
5858
k = [np.fft.fftfreq(n, 1./n).astype(int) for n in N[:-1]]
5959
k.append(np.fft.rfftfreq(N[-1], 1./N[-1]).astype(int))
@@ -62,7 +62,7 @@ def get_local_wavenumbermesh(FFT, L):
6262
Lp = 2*np.pi/L
6363
for i in range(3):
6464
Ks[i] = (Ks[i]*Lp[i]).astype(float)
65-
return [np.broadcast_to(k, FFT.local_shape(True)) for k in Ks]
65+
return [np.broadcast_to(k, FFT.shape(True)) for k in Ks]
6666

6767
X = get_local_mesh(FFT, L)
6868
K = get_local_wavenumbermesh(FFT, L)

mpi4py_fft/distarray.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@ class DistArray(np.ndarray):
1010
"""Distributed Numpy array
1111
1212
This Numpy array is part of a larger global array. Information about the
13-
distribution is contained in the attributes
13+
distribution is contained in the attributes.
1414
1515
Parameters
1616
----------
1717
global_shape : sequence of ints
1818
Shape of non-distributed global array
19-
subcomm : None, Subcomm instance or sequence of ints, optional
19+
subcomm : None, :class:`.Subcomm` object or sequence of ints, optional
2020
Describes how to distribute the array
2121
val : Number or None, optional
2222
Initialize array with this number if buffer is not given
2323
dtype : np.dtype, optional
2424
Type of array
25-
buffer : np.ndarray, optional
25+
buffer : Numpy array, optional
2626
Array of correct shape
2727
alignment : None or int, optional
2828
Make sure array is aligned in this direction. Note that alignment does
@@ -100,7 +100,13 @@ def __array_finalize__(self, obj):
100100

101101
@property
102102
def alignment(self):
103-
"""Return alignment of local ``self`` array"""
103+
"""Return alignment of local ``self`` array
104+
105+
Note
106+
----
107+
For tensors of rank > 0 the array is actually aligned along
108+
``alignment+rank``
109+
"""
104110
return self._p0.axis
105111

106112
@property
@@ -130,7 +136,7 @@ def pencil(self):
130136

131137
@property
132138
def rank(self):
133-
"""Return rank of ``self``"""
139+
"""Return tensor rank of ``self``"""
134140
return self._rank
135141

136142
def __getitem__(self, i):
@@ -191,9 +197,9 @@ def get_global_slice(self, gslice):
191197
s = self.local_slice()
192198
sp = np.nonzero([isinstance(x, slice) for x in gslice])[0]
193199
sf = tuple(np.take(s, sp))
194-
N = self.global_shape
195-
f.require_dataset('0', shape=tuple(np.take(N, sp)), dtype=self.dtype)
200+
f.require_dataset('data', shape=tuple(np.take(self.global_shape, sp)), dtype=self.dtype)
196201
gslice = list(gslice)
202+
# We are required to check if the indices in si are on this processor
197203
si = np.nonzero([isinstance(x, int) and not z == slice(None) for x, z in zip(gslice, s)])[0]
198204
on_this_proc = True
199205
for i in si:
@@ -202,12 +208,12 @@ def get_global_slice(self, gslice):
202208
else:
203209
on_this_proc = False
204210
if on_this_proc:
205-
f["0"][sf] = self[tuple(gslice)]
211+
f["data"][sf] = self[tuple(gslice)]
206212
f.close()
207213
c = None
208214
if comm.Get_rank() == 0:
209215
h = h5py.File('tmp.h5', 'r')
210-
c = h['0'].__array__()
216+
c = h['data'].__array__()
211217
h.close()
212218
os.remove('tmp.h5')
213219
return c
@@ -279,10 +285,10 @@ def redistribute(self, axis=None, darray=None):
279285
280286
Returns
281287
-------
282-
:class:`.DistArray` : darray
288+
DistArray : darray
283289
The ``self`` array globally redistributed. If keyword ``darray`` is
284290
None then a new DistArray (aligned along ``axis``) is created
285-
and returned
291+
and returned. Otherwise the provided darray is returned.
286292
"""
287293
if axis is None:
288294
assert isinstance(darray, np.ndarray)
@@ -308,7 +314,7 @@ def redistribute(self, axis=None, darray=None):
308314
return darray
309315

310316
def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
311-
"""Return a :class:`.DistArray` for provided :class:`.PFFT` object
317+
"""Return a new :class:`.DistArray` object for provided :class:`.PFFT` object
312318
313319
Parameters
314320
----------
@@ -317,15 +323,21 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
317323
If False then create newDistArray of shape/type for input to
318324
forward transform, otherwise create newDistArray of shape/type for
319325
output from forward transform.
320-
val : int or float
326+
val : int or float, optional
321327
Value used to initialize array.
322-
rank: int
328+
rank: int, optional
323329
Scalar has rank 0, vector 1 and matrix 2
324-
view : bool
330+
view : bool, optional
325331
If True return view of the underlying Numpy array, i.e., return
326332
cls.view(np.ndarray). Note that the DistArray still will
327333
be accessible through the base attribute of the view.
328334
335+
Returns
336+
-------
337+
Distarray
338+
A new :class:`.DistArray` object. Return the ``ndarray`` view if
339+
keyword ``view`` is True.
340+
329341
Examples
330342
--------
331343
>>> from mpi4py import MPI
@@ -335,17 +347,13 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
335347
>>> u_hat = newDistArray(FFT, True, rank=1)
336348
337349
"""
350+
global_shape = pfft.global_shape(forward_output)
351+
p0 = pfft.pencil[forward_output]
338352
if forward_output is True:
339-
shape = pfft.forward.output_array.shape
340353
dtype = pfft.forward.output_array.dtype
341-
p0 = pfft.pencil[1]
342354
else:
343-
shape = pfft.forward.input_array.shape
344355
dtype = pfft.forward.input_array.dtype
345-
p0 = pfft.pencil[0]
346-
commsizes = [s.Get_size() for s in p0.subcomm]
347-
global_shape = tuple([s*p for s, p in zip(shape, commsizes)])
348-
global_shape = (len(shape),)*rank + global_shape
356+
global_shape = (len(global_shape),)*rank + global_shape
349357
z = DistArray(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype,
350358
rank=rank)
351359
return z.v if view else z

mpi4py_fft/mpifft.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def destroy(self):
327327
for trans in self.transfer:
328328
trans.destroy()
329329

330-
def local_shape(self, forward_output=True):
330+
def shape(self, forward_output=True):
331331
"""The local (to each processor) shape of data
332332
333333
Parameters
@@ -360,14 +360,27 @@ def local_slice(self, forward_output=True):
360360
ip.subshape)]
361361
return tuple(s)
362362

363-
def shape(self, forward_output=False):
364-
"""Return shape of tensor for space
363+
def local_shape(self, forward_output=False):
364+
"""The local (to each processor) shape of data
365+
366+
Parameters
367+
----------
368+
forward_output : bool, optional
369+
Return shape of output array (spectral space) if True, else return
370+
shape of input array (physical space)
371+
"""
372+
import warnings
373+
warnings.warn("local_shape() is deprecated; use shape().", FutureWarning)
374+
return self.shape(forward_output)
375+
376+
def global_shape(self, forward_output=False):
377+
"""Return global shape of associated tensors
365378
366379
Parameters
367380
----------
368381
forward_output : bool, optional
369-
If True then return shape of spectral space, i.e., the input to
370-
a backward transfer. If False then return shape of physical
382+
If True then return global shape of spectral space, i.e., the input
383+
to a backward transfer. If False then return shape of physical
371384
space, i.e., the input to a forward transfer.
372385
"""
373386
if forward_output:

mpi4py_fft/utilities/h5py_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _write_slice_step(self, name, step, slices, field, **kw):
140140
group = "/".join((name, "{}D".format(ndims), slname))
141141
if group not in self.f:
142142
self.f.create_group(group)
143-
N = self.T.shape(forward_output)
143+
N = self.T.global_shape(forward_output)
144144
self.f[group].require_dataset(str(step), shape=tuple(np.take(N, sp)), dtype=field.dtype)
145145
if inside == 1:
146146
self.f["/".join((group, str(step)))][sf] = field[sl]
@@ -151,5 +151,5 @@ def _write_group(self, name, u, step, **kw):
151151
group = "/".join((name, "{}D".format(self.T.dimensions())))
152152
if group not in self.f:
153153
self.f.create_group(group)
154-
self.f[group].require_dataset(str(step), shape=self.T.shape(forward_output), dtype=u.dtype)
154+
self.f[group].require_dataset(str(step), shape=self.T.global_shape(forward_output), dtype=u.dtype)
155155
self.f["/".join((group, str(step)))][s] = u

mpi4py_fft/utilities/nc_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, ncname, T, domain=None, clobber=True, mode='r', **kw):
4242
from netCDF4 import Dataset
4343
self.filename = ncname
4444
self.f = f = Dataset(ncname, mode=mode, clobber=clobber, parallel=True, comm=comm, **kw)
45-
self.N = N = T.shape(False)[-T.dimensions():]
45+
self.N = N = T.global_shape(False)[-T.dimensions():]
4646
dtype = self.T.dtype(False)
4747
assert dtype.char in 'fdg'
4848
self._dtype = dtype
@@ -66,7 +66,7 @@ def __init__(self, ncname, T, domain=None, clobber=True, mode='r', **kw):
6666
self.dims.append(xyz)
6767
nc_xyz = f.createVariable(xyz, self._dtype, (xyz))
6868
nc_xyz[:] = d[i]
69-
f.setncatts({"ndim": T.dimensions(), "shape": T.shape(False)})
69+
f.setncatts({"ndim": T.dimensions(), "shape": T.global_shape(False)})
7070
f.sync()
7171
self.close()
7272

0 commit comments

Comments
 (0)