4040
4141from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
4242
43+ _max_threads_count = mkl .get_max_threads ()
44+
45+
4346__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
4447 'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
4548 'hfft' , 'ihfft' , 'hfft2' , 'ihfft2' , 'hfftn' , 'ihfftn' ,
@@ -101,9 +104,20 @@ def _tot_size(x, axes):
101104
102105
103106def _workers_to_num_threads (w ):
107+ """Handle conversion of workers to a positive number of threads in the
108+ same way as scipy.fft.helpers._workers.
109+ """
104110 if w is None :
105- return mkl .domain_get_max_threads (domain = 'fft' )
106- return int (w )
111+ return get_workers ()
112+ _w = int (w )
113+ if (_w == 0 ):
114+ raise ValueError ("Number of workers must be nonzero" )
115+ if (_w < 0 ):
116+ _w += _max_threads_count + 1
117+ if _w <= 0 :
118+ raise ValueError ("workers value out of range; got {}, must not be"
119+ " less than {}" .format (w , - _max_threads_count ))
120+ return _w
107121
108122
109123class Workers :
@@ -119,8 +133,7 @@ def __enter__(self):
119133
120134 def __exit__ (self , * args ):
121135 # restore default
122- max_num_threads = mkl .domain_get_max_threads (domain = 'fft' )
123- mkl .domain_set_num_threads (max_num_threads , domain = 'fft' )
136+ mkl .domain_set_num_threads (_max_threads_count , domain = 'fft' )
124137
125138
126139@_implements (_fft .fft )
0 commit comments