2828from . import _float_utils
2929import mkl
3030
31- import scipy .fft as _fft
32-
33- # Complete the namespace (these are not actually used in this module)
34- from scipy .fft import (
35- dct , idct , dst , idst , dctn , idctn , dstn , idstn ,
36- hfft2 , ihfft2 , hfftn , ihfftn ,
37- fftshift , ifftshift , fftfreq , rfftfreq ,
38- get_workers , set_workers
39- )
40-
4131from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
4232from os import cpu_count as os_cpu_count
4333import warnings
4434
35+
36+ __doc__ = """
37+ This module implements interfaces mimicing `scipy.fft` module.
38+
39+ It also provides DftiBackend class which can be used to set mkl_fft to be used
40+ via `scipy.fft` namespace.
41+
42+ :Example:
43+ import scipy.fft
44+ import mkl_fft._scipy_fft_backend as be
45+ # Set mkl_fft to be used as backend of SciPy's FFT functions.
46+ scipy.fft.set_global_backend(be)
47+ """
48+
4549class _cpu_max_threads_count :
4650 def __init__ (self ):
4751 self .cpu_count = None
4852 self .max_threads_count = None
4953
5054 def get_cpu_count (self ):
51- max_threads = self .get_max_threads_count ()
5255 if self .cpu_count is None :
56+ max_threads = self .get_max_threads_count ()
5357 self .cpu_count = os_cpu_count ()
5458 if self .cpu_count > max_threads :
5559 warnings .warn (
@@ -76,30 +80,27 @@ def get_max_threads_count(self):
7680 'fftshift' , 'ifftshift' , 'fftfreq' , 'rfftfreq' , 'get_workers' ,
7781 'set_workers' , 'next_fast_len' , 'DftiBackend' ]
7882
83+ __ua_domain__ = "numpy.scipy.fft"
84+
85+ def __ua_function__ (method , args , kwargs ):
86+ """Fetch registered UA function."""
87+ fn = globals ().get (method .__name__ , None )
88+ if fn is None :
89+ return NotImplemented
90+ return fn (* args , ** kwargs )
91+
7992
8093class DftiBackend :
8194 __ua_domain__ = "numpy.scipy.fft"
8295 @staticmethod
8396 def __ua_function__ (method , args , kwargs ):
8497 """Fetch registered UA function."""
85- fn = __implemented .get (method , None )
98+ fn = globals () .get (method . __name__ , None )
8699 if fn is None :
87100 return NotImplemented
88101 return fn (* args , ** kwargs )
89102
90103
91- __implemented = dict ()
92-
93-
94- def _implements (scipy_func ):
95- """Decorator adds function to the dictionary of implemented UA functions"""
96- def inner (func ):
97- __implemented [scipy_func ] = func
98- return func
99-
100- return inner
101-
102-
103104def _unitary (norm ):
104105 if norm not in (None , "ortho" ):
105106 raise ValueError ("Invalid norm value %s, should be None or \" ortho\" ."
@@ -138,7 +139,7 @@ def _workers_to_num_threads(w):
138139 same way as scipy.fft.helpers._workers.
139140 """
140141 if w is None :
141- return get_workers ()
142+ return _hardware_counts . get_cpu_count ()
142143 _w = int (w )
143144 if (_w == 0 ):
144145 raise ValueError ("Number of workers must be nonzero" )
@@ -166,7 +167,6 @@ def __exit__(self, *args):
166167 mkl .set_num_threads_local (self .prev_num_threads )
167168
168169
169- @_implements (_fft .fft )
170170def fft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
171171 x = _float_utils .__upcast_float16_array (a )
172172 with Workers (workers ):
@@ -176,7 +176,6 @@ def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
176176 return output
177177
178178
179- @_implements (_fft .ifft )
180179def ifft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
181180 x = _float_utils .__upcast_float16_array (a )
182181 with Workers (workers ):
@@ -186,7 +185,6 @@ def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
186185 return output
187186
188187
189- @_implements (_fft .fft2 )
190188def fft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
191189 x = _float_utils .__upcast_float16_array (a )
192190 with Workers (workers ):
@@ -199,7 +197,6 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
199197 return output
200198
201199
202- @_implements (_fft .ifft2 )
203200def ifft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
204201 x = _float_utils .__upcast_float16_array (a )
205202 with Workers (workers ):
@@ -213,7 +210,6 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
213210 return output
214211
215212
216- @_implements (_fft .fftn )
217213def fftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
218214 x = _float_utils .__upcast_float16_array (a )
219215 with Workers (workers ):
@@ -227,7 +223,6 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
227223 return output
228224
229225
230- @_implements (_fft .ifftn )
231226def ifftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
232227 x = _float_utils .__upcast_float16_array (a )
233228 with Workers (workers ):
@@ -241,7 +236,6 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
241236 return output
242237
243238
244- @_implements (_fft .rfft )
245239def rfft (a , n = None , axis = - 1 , norm = None , workers = None ):
246240 x = _float_utils .__upcast_float16_array (a )
247241 unitary = _unitary (norm )
@@ -256,7 +250,6 @@ def rfft(a, n=None, axis=-1, norm=None, workers=None):
256250 return output
257251
258252
259- @_implements (_fft .irfft )
260253def irfft (a , n = None , axis = - 1 , norm = None , workers = None ):
261254 x = _float_utils .__upcast_float16_array (a )
262255 x = _float_utils .__downcast_float128_array (x )
@@ -267,21 +260,18 @@ def irfft(a, n=None, axis=-1, norm=None, workers=None):
267260 return output
268261
269262
270- @_implements (_fft .rfft2 )
271263def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None , workers = None ):
272264 x = _float_utils .__upcast_float16_array (a )
273265 x = _float_utils .__downcast_float128_array (a )
274266 return rfftn (x , s , axes , norm , workers )
275267
276268
277- @_implements (_fft .irfft2 )
278269def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None , workers = None ):
279270 x = _float_utils .__upcast_float16_array (a )
280271 x = _float_utils .__downcast_float128_array (x )
281272 return irfftn (x , s , axes , norm , workers )
282273
283274
284- @_implements (_fft .rfftn )
285275def rfftn (a , s = None , axes = None , norm = None , workers = None ):
286276 unitary = _unitary (norm )
287277 x = _float_utils .__upcast_float16_array (a )
@@ -297,7 +287,6 @@ def rfftn(a, s=None, axes=None, norm=None, workers=None):
297287 return output
298288
299289
300- @_implements (_fft .irfftn )
301290def irfftn (a , s = None , axes = None , norm = None , workers = None ):
302291 x = _float_utils .__upcast_float16_array (a )
303292 x = _float_utils .__downcast_float128_array (x )
0 commit comments