Skip to content

Commit 5aa785e

Browse files
committed
Working on DistributedArray, adding getDarray instead of Function
1 parent 47344d7 commit 5aa785e

File tree

12 files changed

+324
-167
lines changed

12 files changed

+324
-167
lines changed

docs/Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ help:
1414

1515
.PHONY: help Makefile
1616

17+
doctest:
18+
@$(SPHINXBUILD) -b doctest "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
19+
1720
# Catch-all target: route all unknown targets to Sphinx using the new
1821
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
1922
%: Makefile

docs/source/mpi4py_fft.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ mpi4py\_fft.pencil module
3838
:undoc-members:
3939
:show-inheritance:
4040

41+
mpi4py\_fft.distributedarray module
42+
-----------------------------------
43+
44+
.. automodule:: mpi4py_fft.distributedarray
45+
:members:
46+
:undoc-members:
47+
:show-inheritance:
48+
4149
Module contents
4250
---------------
4351

examples/darray.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,49 @@
11
import numpy as np
22
from mpi4py import MPI
33
from mpi4py_fft.pencil import Subcomm
4-
from mpi4py_fft.distributedarray import DistributedArray
5-
from mpi4py_fft.mpifft import PFFT, Function
4+
from mpi4py_fft.distributedarray import DistributedArray, getDarray
5+
from mpi4py_fft.mpifft import PFFT
66

77
# Test DistributedArray. Start with alignment in axis 0, then tranfer to 1 and
88
# finally to 2
99
N = (16, 14, 12)
1010
z0 = DistributedArray(N, dtype=np.float, alignment=0)
1111
z0[:] = np.random.randint(0, 10, z0.shape)
1212
s0 = MPI.COMM_WORLD.allreduce(np.sum(z0))
13-
print(MPI.COMM_WORLD.Get_rank(), z0.shape)
14-
z1 = z0.redistribute(1)
13+
z1 = z0.redistribute(2)
1514
s1 = MPI.COMM_WORLD.allreduce(np.sum(z1))
16-
print(MPI.COMM_WORLD.Get_rank(), z1.shape)
17-
z2 = z1.redistribute(2)
18-
print(MPI.COMM_WORLD.Get_rank(), z2.shape)
19-
15+
z2 = z1.redistribute(1)
2016
s2 = MPI.COMM_WORLD.allreduce(np.sum(z2))
2117
assert s0 == s1 == s2
2218

23-
fft = PFFT(Subcomm(MPI.COMM_WORLD, [s.Get_size() for s in z2.p0.subcomm]), N, dtype=z2.dtype)
24-
z3 = Function(fft, True)
25-
fft.forward(z2, z3)
19+
fft = PFFT(MPI.COMM_WORLD, darray=z2, axes=(0, 2, 1))
20+
z3 = getDarray(fft, forward_output=True)
21+
z2c = z2.copy()
22+
fft.forward(z2, z3)
23+
fft.backward(z3, z2)
24+
s0, s1 = np.linalg.norm(z2), np.linalg.norm(z2c)
25+
assert abs(s0-s1) < 1e-12, s0-s1
26+
27+
print(z3.local_slice(), z3.substart, z3.commsizes)
28+
29+
v0 = getDarray(fft, forward_output=False, rank=1)
30+
v0[:] = np.random.random(v0.shape)
31+
v0c = v0.copy()
32+
v1 = getDarray(fft, forward_output=True, rank=1)
33+
34+
for i in range(3):
35+
v1[i] = fft.forward(v0[i], v1[i])
36+
for i in range(3):
37+
v0[i] = fft.backward(v1[i], v0[i])
38+
s0, s1 = np.linalg.norm(v0c), np.linalg.norm(v0)
39+
assert abs(s0-s1) < 1e-12
40+
41+
print(v0.substart, v0.commsizes)
42+
43+
nfft = PFFT(MPI.COMM_WORLD, darray=v0[0], axes=(0, 2, 1))
44+
for i in range(3):
45+
v1[i] = nfft.forward(v0[i], v1[i])
46+
for i in range(3):
47+
v0[i] = nfft.backward(v1[i], v0[i])
48+
s0, s1 = np.linalg.norm(v0c), np.linalg.norm(v0)
49+
assert abs(s0-s1) < 1e-12

examples/spectral_dns_solver.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from time import time
1111
import numpy as np
1212
from mpi4py import MPI
13-
from mpi4py_fft.mpifft import PFFT, Function
13+
from mpi4py_fft import PFFT, getDarray
1414

1515
# Set viscosity, end time and time step
1616
nu = 0.000625
@@ -28,18 +28,18 @@
2828
FFT_pad = FFT
2929

3030
# Declare variables needed to solve Navier-Stokes
31-
U = Function(FFT, False, tensor=3) # Velocity
32-
U_hat = Function(FFT, tensor=3) # Velocity transformed
33-
P = Function(FFT, False) # Pressure (scalar)
34-
P_hat = Function(FFT) # Pressure transformed
35-
U_hat0 = Function(FFT, tensor=3) # Runge-Kutta work array
36-
U_hat1 = Function(FFT, tensor=3) # Runge-Kutta work array
37-
a = [1./6., 1./3., 1./3., 1./6.] # Runge-Kutta parameter
38-
b = [0.5, 0.5, 1.] # Runge-Kutta parameter
39-
dU = Function(FFT, tensor=3) # Right hand side of ODEs
40-
curl = Function(FFT, False, tensor=3)
41-
U_pad = Function(FFT_pad, False, tensor=3)
42-
curl_pad = Function(FFT_pad, False, tensor=3)
31+
U = getDarray(FFT, False, rank=1) # Velocity
32+
U_hat = getDarray(FFT, rank=1) # Velocity transformed
33+
P = getDarray(FFT, False) # Pressure (scalar)
34+
P_hat = getDarray(FFT) # Pressure transformed
35+
U_hat0 = getDarray(FFT, rank=1) # Runge-Kutta work array
36+
U_hat1 = getDarray(FFT, rank=1) # Runge-Kutta work array
37+
a = [1./6., 1./3., 1./3., 1./6.] # Runge-Kutta parameter
38+
b = [0.5, 0.5, 1.] # Runge-Kutta parameter
39+
dU = getDarray(FFT, rank=1) # Right hand side of ODEs
40+
curl = getDarray(FFT, False, rank=1)
41+
U_pad = getDarray(FFT_pad, False, rank=1)
42+
curl_pad = getDarray(FFT_pad, False, rank=1)
4343

4444
def get_local_mesh(FFT, L):
4545
"""Returns local mesh."""
@@ -52,10 +52,8 @@ def get_local_mesh(FFT, L):
5252

5353
def get_local_wavenumbermesh(FFT, L):
5454
"""Returns local wavenumber mesh."""
55-
5655
s = FFT.local_slice()
5756
N = FFT.shape()
58-
5957
# Set wavenumbers in grid
6058
k = [np.fft.fftfreq(n, 1./n).astype(int) for n in N[:-1]]
6159
k.append(np.fft.rfftfreq(N[-1], 1./N[-1]).astype(int))

examples/transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import numpy as np
33
from mpi4py import MPI
4-
from mpi4py_fft.mpifft import PFFT, Function
4+
from mpi4py_fft import PFFT, DistributedArray
55
from mpi4py_fft.fftw import dctn, idctn
66

77
# Set global size of the computational box
@@ -17,16 +17,16 @@
1717

1818
assert fft.axes == pfft.axes
1919

20-
u = Function(fft, False)
20+
u = DistributedArray(pfft=fft, forward_output=False)
2121
u[:] = np.random.random(u.shape).astype(u.dtype)
2222

23-
u_hat = Function(fft)
23+
u_hat = DistributedArray(pfft=fft, forward_output=True)
2424
u_hat = fft.forward(u, u_hat)
2525
uj = np.zeros_like(u)
2626
uj = fft.backward(u_hat, uj)
2727
assert np.allclose(uj, u)
2828

29-
u_padded = Function(pfft, False)
29+
u_padded = DistributedArray(pfft=pfft, forward_output=False)
3030
uc = u_hat.copy()
3131
u_padded = pfft.backward(u_hat, u_padded)
3232
u_hat = pfft.forward(u_padded, u_hat)

mpi4py_fft/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
__version__ = '1.1.2'
2020
__author__ = 'Lisandro Dalcin and Mikael Mortensen'
2121

22-
from .distributedarray import DistributedArray
23-
from .mpifft import PFFT, Function
22+
from .distributedarray import DistributedArray, getDarray
23+
from .mpifft import PFFT
2424
from . import fftw
2525
from .utilities import HDF5File, NCFile, generate_xdmf

0 commit comments

Comments
 (0)