Skip to content

Commit cbc69c8

Browse files
committed
Fixing generate_xdmf because paraview now requires 2D data to be 3D
1 parent ac510f8 commit cbc69c8

File tree

7 files changed

+26
-11
lines changed

7 files changed

+26
-11
lines changed

mpi4py_fft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
from .distarray import DistArray, newDistArray, Function
2323
from .mpifft import PFFT
2424
from . import fftw
25+
from .fftw import fftlib
2526
from .io import HDF5File, NCFile, generate_xdmf

mpi4py_fft/distarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ class DistArray(np.ndarray):
5656
5757
"""
5858
def __new__(cls, global_shape, subcomm=None, val=None, dtype=float,
59-
buffer=None, alignment=None, rank=0):
59+
buffer=None, strides=None, alignment=None, rank=0):
6060
if len(global_shape[rank:]) < 2: # 1D case
61-
obj = np.ndarray.__new__(cls, global_shape, dtype=dtype, buffer=buffer)
61+
obj = np.ndarray.__new__(cls, global_shape, dtype=dtype, buffer=buffer, strides=strides)
6262
if buffer is None and isinstance(val, Number):
6363
obj.fill(val)
6464
obj._rank = rank

mpi4py_fft/fftw/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .xfftn import *
22
from .factory import get_planned_FFT, export_wisdom, import_wisdom, \
3-
forget_wisdom, cleanup, set_timelimit, get_fftw_lib
3+
forget_wisdom, cleanup, set_timelimit, get_fftw_lib, fftlib

mpi4py_fft/io/generate_xdmf.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,16 @@ def get_geometry(kind=0, dim=2):
3838
</DataItem>
3939
</Geometry>"""
4040

41-
return """<Geometry Type="VXVY">
41+
return """<Geometry Type="VXVYVZ">
4242
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{1}">
4343
{3}:{6}/mesh/{4}
4444
</DataItem>
4545
<DataItem Format="HDF" NumberType="Float" Precision="{0}" Dimensions="{2}">
4646
{3}:{6}/mesh/{5}
4747
</DataItem>
48+
<DataItem Format="XML" NumberType="Float" Precision="8" Dimensions="1">
49+
0
50+
</DataItem>
4851
</Geometry>"""
4952

5053
if dim == 3:
@@ -74,7 +77,7 @@ def get_topology(dims, kind=0):
7477
assert len(dims) in (2, 3)
7578
co = 'Co' if kind == 0 else ''
7679
if len(dims) == 2:
77-
return """<Topology Dimensions="{0} {1}" Type="2D{2}RectMesh"/>""".format(dims[0], dims[1], co)
80+
return """<Topology Dimensions="1 {0} {1}" Type="3D{2}RectMesh"/>""".format(dims[0], dims[1], co)
7881
if len(dims) == 3:
7982
return """<Topology Dimensions="{0} {1} {2}" Type="3D{3}RectMesh"/>""".format(dims[0], dims[1], dims[2], co)
8083

@@ -83,7 +86,7 @@ def get_attribute(attr, h5filename, dims, prec):
8386
assert len(dims) in (2, 3)
8487
if len(dims) == 2:
8588
return """<Attribute Name="{0}" Center="Node">
86-
<DataItem Format="HDF" NumberType="Float" Precision="{5}" Dimensions="{1} {2}">
89+
<DataItem Format="HDF" NumberType="Float" Precision="{5}" Dimensions="1 {1} {2}">
8790
{3}:/{4}
8891
</DataItem>
8992
</Attribute>
@@ -96,7 +99,7 @@ def get_attribute(attr, h5filename, dims, prec):
9699
</Attribute>
97100
""".format(name, dims[0], dims[1], dims[2], h5filename, attr, prec)
98101

99-
def generate_xdmf(h5filename, periodic=True, order='paraview'):
102+
def generate_xdmf(h5filename, periodic=True, order='visit'):
100103
"""Generate XDMF-files
101104
102105
Parameters

mpi4py_fft/mpifft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def shape(self, forward_output=True):
363363
"""
364364
if forward_output is not True:
365365
return self.forward.input_pencil.subshape
366-
return self.backward.input_pencil.subshape
366+
return self.forward.output_array.shape
367367

368368
def local_slice(self, forward_output=True):
369369
"""The local view into the global data

tests/test_io.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ def remove_if_exists(filename):
2121
except OSError:
2222
pass
2323

24+
def cleanup():
25+
import glob
26+
files = glob.glob('*.h5')+glob.glob('*.xdmf')+glob.glob('*.nc')
27+
for f in files:
28+
remove_if_exists(f)
29+
2430
def test_2D(backend, forward_output):
2531
if backend == 'netcdf4':
2632
assert forward_output is False
@@ -162,6 +168,7 @@ def test_4D(backend, forward_output):
162168

163169
if __name__ == '__main__':
164170
#pylint: disable=unused-import
171+
cleanup()
165172
skip = {'hdf5': False, 'netcdf4': False}
166173
try:
167174
import h5py
@@ -181,3 +188,4 @@ def test_4D(backend, forward_output):
181188
test_2D(bnd, kind)
182189
if bnd == 'hdf5':
183190
test_4D(bnd, kind)
191+
cleanup()

tests/test_speed.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import pyfftw
44
import scipy.fftpack as sp
55
from mpi4py_fft import fftw
6+
import pickle
67

78
try:
8-
fftw.import_wisdom('wisdom.dat')
9-
except AssertionError:
10-
pass
9+
#fftw.import_wisdom('wisdom.dat')
10+
pyfftw.import_wisdom(pickle.load(open('pyfftw.wisdom', 'rb')))
11+
print('Wisdom imported')
12+
except:
13+
print('Wisdom not imported')
1114

1215
N = (64, 64, 64)
1316
loops = 50

0 commit comments

Comments
 (0)