2626
2727import mkl_fft .interfaces as mfi
2828import pytest
29+ import numpy as np
2930
3031
3132def test_interfaces_has_numpy ():
@@ -34,3 +35,43 @@ def test_interfaces_has_numpy():
3435
3536def test_interfaces_has_scipy ():
3637 assert hasattr (mfi , 'scipy_fft' )
38+
39+
40+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
41+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 , np .complex64 , np .complex128 ])
42+ def test_scipy_fft (norm , dtype ):
43+ x = np .ones (511 , dtype = dtype )
44+ w = mfi .scipy_fft .fft (x , norm = norm )
45+ xx = mfi .scipy_fft .ifft (w , norm = norm )
46+ tol = 64 * np .finfo (np .dtype (dtype )).eps
47+ assert np .allclose (x , xx , atol = tol , rtol = tol )
48+
49+
50+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
51+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 ])
52+ def test_scipy_rfft (norm , dtype ):
53+ x = np .ones (511 , dtype = dtype )
54+ w = mfi .scipy_fft .rfft (x , norm = norm )
55+ xx = mfi .scipy_fft .irfft (w , n = x .shape [0 ], norm = norm )
56+ tol = 64 * np .finfo (np .dtype (dtype )).eps
57+ assert np .allclose (x , xx , atol = tol , rtol = tol )
58+
59+
60+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
61+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 , np .complex64 , np .complex128 ])
62+ def test_scipy_fftn (norm , dtype ):
63+ x = np .ones ((37 , 83 ), dtype = dtype )
64+ w = mfi .scipy_fft .fftn (x , norm = norm )
65+ xx = mfi .scipy_fft .ifftn (w , norm = norm )
66+ tol = 64 * np .finfo (np .dtype (dtype )).eps
67+ assert np .allclose (x , xx , atol = tol , rtol = tol )
68+
69+
70+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
71+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 ])
72+ def test_scipy_rftn (norm , dtype ):
73+ x = np .ones ((37 , 83 ), dtype = dtype )
74+ w = mfi .scipy_fft .rfftn (x , norm = norm )
75+ xx = mfi .scipy_fft .ifftn (w , s = x .shape , norm = norm )
76+ tol = 64 * np .finfo (np .dtype (dtype )).eps
77+ assert np .allclose (x , xx , atol = tol , rtol = tol )
0 commit comments