Skip to content

Commit 9951eeb

Browse files
Fixed _scipy_fft_backend
1 parent 50e10bd commit 9951eeb

File tree

1 file changed

+26
-37
lines changed

1 file changed

+26
-37
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,32 @@
2828
from . import _float_utils
2929
import mkl
3030

31-
import scipy.fft as _fft
32-
33-
# Complete the namespace (these are not actually used in this module)
34-
from scipy.fft import (
35-
dct, idct, dst, idst, dctn, idctn, dstn, idstn,
36-
hfft2, ihfft2, hfftn, ihfftn,
37-
fftshift, ifftshift, fftfreq, rfftfreq,
38-
get_workers, set_workers
39-
)
40-
4131
from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
4232
from os import cpu_count as os_cpu_count
4333
import warnings
4434

35+
36+
__doc__ = """
37+
This module implements interfaces mimicing `scipy.fft` module.
38+
39+
It also provides DftiBackend class which can be used to set mkl_fft to be used
40+
via `scipy.fft` namespace.
41+
42+
:Example:
43+
import scipy.fft
44+
import mkl_fft._scipy_fft_backend as be
45+
# Set mkl_fft to be used as backend of SciPy's FFT functions.
46+
scipy.fft.set_global_backend(be)
47+
"""
48+
4549
class _cpu_max_threads_count:
4650
def __init__(self):
4751
self.cpu_count = None
4852
self.max_threads_count = None
4953

5054
def get_cpu_count(self):
51-
max_threads = self.get_max_threads_count()
5255
if self.cpu_count is None:
56+
max_threads = self.get_max_threads_count()
5357
self.cpu_count = os_cpu_count()
5458
if self.cpu_count > max_threads:
5559
warnings.warn(
@@ -76,30 +80,27 @@ def get_max_threads_count(self):
7680
'fftshift', 'ifftshift', 'fftfreq', 'rfftfreq', 'get_workers',
7781
'set_workers', 'next_fast_len', 'DftiBackend']
7882

83+
__ua_domain__ = "numpy.scipy.fft"
84+
85+
def __ua_function__(method, args, kwargs):
86+
"""Fetch registered UA function."""
87+
fn = globals().get(method.__name__, None)
88+
if fn is None:
89+
return NotImplemented
90+
return fn(*args, **kwargs)
91+
7992

8093
class DftiBackend:
8194
__ua_domain__ = "numpy.scipy.fft"
8295
@staticmethod
8396
def __ua_function__(method, args, kwargs):
8497
"""Fetch registered UA function."""
85-
fn = __implemented.get(method, None)
98+
fn = globals().get(method.__name__, None)
8699
if fn is None:
87100
return NotImplemented
88101
return fn(*args, **kwargs)
89102

90103

91-
__implemented = dict()
92-
93-
94-
def _implements(scipy_func):
95-
"""Decorator adds function to the dictionary of implemented UA functions"""
96-
def inner(func):
97-
__implemented[scipy_func] = func
98-
return func
99-
100-
return inner
101-
102-
103104
def _unitary(norm):
104105
if norm not in (None, "ortho"):
105106
raise ValueError("Invalid norm value %s, should be None or \"ortho\"."
@@ -138,7 +139,7 @@ def _workers_to_num_threads(w):
138139
same way as scipy.fft.helpers._workers.
139140
"""
140141
if w is None:
141-
return get_workers()
142+
return _hardware_counts.get_cpu_count()
142143
_w = int(w)
143144
if (_w == 0):
144145
raise ValueError("Number of workers must be nonzero")
@@ -166,7 +167,6 @@ def __exit__(self, *args):
166167
mkl.set_num_threads_local(self.prev_num_threads)
167168

168169

169-
@_implements(_fft.fft)
170170
def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
171171
x = _float_utils.__upcast_float16_array(a)
172172
with Workers(workers):
@@ -176,7 +176,6 @@ def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
176176
return output
177177

178178

179-
@_implements(_fft.ifft)
180179
def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
181180
x = _float_utils.__upcast_float16_array(a)
182181
with Workers(workers):
@@ -186,7 +185,6 @@ def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
186185
return output
187186

188187

189-
@_implements(_fft.fft2)
190188
def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
191189
x = _float_utils.__upcast_float16_array(a)
192190
with Workers(workers):
@@ -199,7 +197,6 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
199197
return output
200198

201199

202-
@_implements(_fft.ifft2)
203200
def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
204201
x = _float_utils.__upcast_float16_array(a)
205202
with Workers(workers):
@@ -213,7 +210,6 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
213210
return output
214211

215212

216-
@_implements(_fft.fftn)
217213
def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
218214
x = _float_utils.__upcast_float16_array(a)
219215
with Workers(workers):
@@ -227,7 +223,6 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
227223
return output
228224

229225

230-
@_implements(_fft.ifftn)
231226
def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
232227
x = _float_utils.__upcast_float16_array(a)
233228
with Workers(workers):
@@ -241,7 +236,6 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
241236
return output
242237

243238

244-
@_implements(_fft.rfft)
245239
def rfft(a, n=None, axis=-1, norm=None, workers=None):
246240
x = _float_utils.__upcast_float16_array(a)
247241
unitary = _unitary(norm)
@@ -256,7 +250,6 @@ def rfft(a, n=None, axis=-1, norm=None, workers=None):
256250
return output
257251

258252

259-
@_implements(_fft.irfft)
260253
def irfft(a, n=None, axis=-1, norm=None, workers=None):
261254
x = _float_utils.__upcast_float16_array(a)
262255
x = _float_utils.__downcast_float128_array(x)
@@ -267,21 +260,18 @@ def irfft(a, n=None, axis=-1, norm=None, workers=None):
267260
return output
268261

269262

270-
@_implements(_fft.rfft2)
271263
def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
272264
x = _float_utils.__upcast_float16_array(a)
273265
x = _float_utils.__downcast_float128_array(a)
274266
return rfftn(x, s, axes, norm, workers)
275267

276268

277-
@_implements(_fft.irfft2)
278269
def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
279270
x = _float_utils.__upcast_float16_array(a)
280271
x = _float_utils.__downcast_float128_array(x)
281272
return irfftn(x, s, axes, norm, workers)
282273

283274

284-
@_implements(_fft.rfftn)
285275
def rfftn(a, s=None, axes=None, norm=None, workers=None):
286276
unitary = _unitary(norm)
287277
x = _float_utils.__upcast_float16_array(a)
@@ -297,7 +287,6 @@ def rfftn(a, s=None, axes=None, norm=None, workers=None):
297287
return output
298288

299289

300-
@_implements(_fft.irfftn)
301290
def irfftn(a, s=None, axes=None, norm=None, workers=None):
302291
x = _float_utils.__upcast_float16_array(a)
303292
x = _float_utils.__downcast_float128_array(x)

0 commit comments

Comments
 (0)