Skip to content

Commit 9b09967

Browse files
committed
Merged in dalcinl/checks (pull request #13)
Add error checking for NULL FFTW plan pointers
2 parents 8f37fa3 + 9294bb1 commit 9b09967

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

mpi4py_fft/fftw/fftw_xfftn.pyx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ cdef class FFT:
149149
knd[i] = kind[i]
150150
self._plan = fftw_planxfftn(ndims, sz_in, _in, sz_out, _out, naxes,
151151
axs, knd, allflags)
152+
if self._plan == NULL:
153+
raise RuntimeError("Failure creating FFTW plan")
152154
free(sz_in)
153155
free(sz_out)
154156
free(axs)
@@ -169,6 +171,7 @@ cdef class FFT:
169171
return self._output_array
170172

171173
def print_plan(self):
174+
assert self._plan != NULL
172175
fftw_print_plan(<fftw_plan>self._plan)
173176

174177
def update_arrays(self, input_array, output_array):
@@ -235,6 +238,7 @@ cdef class FFT:
235238
"""Apply plan with explicit (and safe) update of work arrays"""
236239
if input_array is not None:
237240
self._input_array[...] = input_array
241+
assert self._plan != NULL
238242
with nogil:
239243
fftw_execute(<fftw_plan>self._plan)
240244
if normalize:
@@ -283,6 +287,7 @@ cdef class FFT:
283287

284288
_in = <void *>np.PyArray_DATA(input_array)
285289
_out = <void *>np.PyArray_DATA(output_array)
290+
assert self._plan != NULL
286291
with nogil:
287292
apply_plan(<fftw_plan>self._plan, _in, _out)
288293
if normalize:

tests/test_fftw.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_fftw():
103103
idct = fftw.idctn(input_array, None, axes, type, threads, iflags, output_array=oa)
104104
A2 = idct(B, implicit=True, normalize=True)
105105
assert allclose(A, A2), np.linalg.norm(A-A2)
106-
if typecode is not 'g' and not type is 4:
106+
if typecode != 'g' and type != 4:
107107
B2 = scipy_dctn(A, axes=axes, type=type)
108108
assert allclose(B, B2), np.linalg.norm(B-B2)
109109

@@ -112,7 +112,7 @@ def test_fftw():
112112
idst = fftw.idstn(input_array, None, axes, type, threads, iflags, output_array=oa)
113113
A2 = idst(B, implicit=True, normalize=True)
114114
assert allclose(A, A2), np.linalg.norm(A-A2)
115-
if typecode is not 'g' and not type is 4:
115+
if typecode != 'g' and type != 4:
116116
B2 = scipy_dstn(A, axes=axes, type=type)
117117
assert allclose(B, B2), np.linalg.norm(B-B2)
118118

@@ -132,7 +132,7 @@ def test_fftw():
132132
M = fftw.get_normalization(kds, input_array.shape, axes)
133133
assert allclose(C2*M, A)
134134
# Test vs scipy for transforms available in scipy
135-
if typecode is not 'g' and not any(f in kds for f in (fftw.FFTW_RODFT11, fftw.FFTW_REDFT11)):
135+
if typecode != 'g' and not any(f in kds for f in (fftw.FFTW_RODFT11, fftw.FFTW_REDFT11)):
136136
for m, ts in enumerate(tsf):
137137
A = eval('scipy.fftpack.'+ts[:-1])(A, axis=axes[m], type=int(ts[-1]))
138138
assert allclose(C, A), np.linalg.norm(C-A)

0 commit comments

Comments
 (0)