2929
3030import mkl_fft .interfaces as mfi
3131
32+ try :
33+ scipy_fft = mfi .scipy_fft
34+ except AttributeError :
35+ scipy_fft = None
36+
37+ interfaces = []
38+ ids = []
39+ if scipy_fft is not None :
40+ interfaces .append (scipy_fft )
41+ ids .append ("scipy" )
42+ interfaces .append (mfi .numpy_fft )
43+ ids .append ("numpy" )
44+
3245
3346@pytest .mark .parametrize ("norm" , [None , "forward" , "backward" , "ortho" ])
3447@pytest .mark .parametrize (
3548 "dtype" , [np .float32 , np .float64 , np .complex64 , np .complex128 ]
3649)
3750def test_scipy_fft (norm , dtype ):
51+ pytest .importorskip ("scipy" , reason = "requires scipy" )
3852 x = np .ones (511 , dtype = dtype )
3953 w = mfi .scipy_fft .fft (x , norm = norm , workers = None , plan = None )
4054 xx = mfi .scipy_fft .ifft (w , norm = norm , workers = None , plan = None )
@@ -57,6 +71,7 @@ def test_numpy_fft(norm, dtype):
5771@pytest .mark .parametrize ("norm" , [None , "forward" , "backward" , "ortho" ])
5872@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
5973def test_scipy_rfft (norm , dtype ):
74+ pytest .importorskip ("scipy" , reason = "requires scipy" )
6075 x = np .ones (511 , dtype = dtype )
6176 w = mfi .scipy_fft .rfft (x , norm = norm , workers = None , plan = None )
6277 xx = mfi .scipy_fft .irfft (
@@ -87,6 +102,7 @@ def test_numpy_rfft(norm, dtype):
87102 "dtype" , [np .float32 , np .float64 , np .complex64 , np .complex128 ]
88103)
89104def test_scipy_fftn (norm , dtype ):
105+ pytest .importorskip ("scipy" , reason = "requires scipy" )
90106 x = np .ones ((37 , 83 ), dtype = dtype )
91107 w = mfi .scipy_fft .fftn (x , norm = norm , workers = None , plan = None )
92108 xx = mfi .scipy_fft .ifftn (w , norm = norm , workers = None , plan = None )
@@ -109,6 +125,7 @@ def test_numpy_fftn(norm, dtype):
109125@pytest .mark .parametrize ("norm" , [None , "forward" , "backward" , "ortho" ])
110126@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
111127def test_scipy_rfftn (norm , dtype ):
128+ pytest .importorskip ("scipy" , reason = "requires scipy" )
112129 x = np .ones ((37 , 83 ), dtype = dtype )
113130 w = mfi .scipy_fft .rfftn (x , norm = norm , workers = None , plan = None )
114131 xx = mfi .scipy_fft .irfftn (w , s = x .shape , norm = norm , workers = None , plan = None )
@@ -143,32 +160,30 @@ def _get_blacklisted_dtypes():
143160
144161@pytest .mark .parametrize ("dtype" , _get_blacklisted_dtypes ())
145162def test_scipy_no_support_for (dtype ):
163+ pytest .importorskip ("scipy" , reason = "requires scipy" )
146164 x = np .ones (16 , dtype = dtype )
147165 assert_raises (NotImplementedError , mfi .scipy_fft .ifft , x )
148166
149167
150168def test_scipy_fft_arg_validate ():
169+ pytest .importorskip ("scipy" , reason = "requires scipy" )
151170 with pytest .raises (ValueError ):
152171 mfi .scipy_fft .fft ([1 , 2 , 3 , 4 ], norm = b"invalid" )
153172
154173 with pytest .raises (NotImplementedError ):
155174 mfi .scipy_fft .fft ([1 , 2 , 3 , 4 ], plan = "magic" )
156175
157176
158- @pytest .mark .parametrize (
159- "func" , [mfi .scipy_fft .rfft2 , mfi .numpy_fft .rfft2 ], ids = ["scipy" , "numpy" ]
160- )
161- def test_axes (func ):
177+ @pytest .mark .parametrize ("interface" , interfaces , ids = ids )
178+ def test_axes (interface ):
162179 x = np .arange (24.0 ).reshape (2 , 3 , 4 )
163- res = func (x , axes = (1 , 2 ))
180+ res = interface . rfft2 (x , axes = (1 , 2 ))
164181 exp = np .fft .rfft2 (x , axes = (1 , 2 ))
165182 tol = 64 * np .finfo (np .float64 ).eps
166183 assert np .allclose (res , exp , atol = tol , rtol = tol )
167184
168185
169- @pytest .mark .parametrize (
170- "interface" , [mfi .scipy_fft , mfi .numpy_fft ], ids = ["scipy" , "numpy" ]
171- )
186+ @pytest .mark .parametrize ("interface" , interfaces , ids = ids )
172187@pytest .mark .parametrize (
173188 "func" , ["fftshift" , "ifftshift" , "fftfreq" , "rfftfreq" ]
174189)
0 commit comments