3939)
4040
4141from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42+ from os import cpu_count as os_cpu_count
43+ import warnings
44+
45+ class _cpu_max_threads_count :
46+ def __init__ (self ):
47+ self .cpu_count = None
48+ self .max_threads_count = None
49+
50+ def get_cpu_count (self ):
51+ max_threads = self .get_max_threads_count ()
52+ if self .cpu_count is None :
53+ self .cpu_count = os_cpu_count ()
54+ if self .cpu_count > max_threads :
55+ warnings .warn (
56+ ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
57+ "Using negative values of worker option may amount to requesting more threads than "
58+ "Intel(R) MKL can acommodate."
59+ ).format (self .cpu_count , max_threads ))
60+ return self .cpu_count
61+
62+ def get_max_threads_count (self ):
63+ if self .max_threads_count is None :
64+ self .max_threads_count = mkl .get_max_threads ()
65+
66+ return self .max_threads_count
67+
68+
69+ _hardware_counts = _cpu_max_threads_count ()
70+
4271
4372__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
4473 'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
@@ -101,9 +130,20 @@ def _tot_size(x, axes):
101130
102131
103132def _workers_to_num_threads (w ):
133+ """Handle conversion of workers to a positive number of threads in the
134+ same way as scipy.fft.helpers._workers.
135+ """
104136 if w is None :
105- return mkl .domain_get_max_threads (domain = 'fft' )
106- return int (w )
137+ return get_workers ()
138+ _w = int (w )
139+ if (_w == 0 ):
140+ raise ValueError ("Number of workers must be nonzero" )
141+ if (_w < 0 ):
142+ _w += _hardware_counts .get_cpu_count () + 1
143+ if _w <= 0 :
144+ raise ValueError ("workers value out of range; got {}, must not be"
145+ " less than {}" .format (w , - _hardware_counts .get_cpu_count ()))
146+ return _w
107147
108148
109149class Workers :
@@ -119,8 +159,8 @@ def __enter__(self):
119159
120160 def __exit__ (self , * args ):
121161 # restore default
122- max_num_threads = mkl . domain_get_max_threads ( domain = 'fft' )
123- mkl .domain_set_num_threads (max_num_threads , domain = 'fft' )
162+ n_threads = _hardware_counts . get_max_threads_count ( )
163+ mkl .domain_set_num_threads (n_threads , domain = 'fft' )
124164
125165
126166@_implements (_fft .fft )
0 commit comments