3939)
4040
4141from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42+ from os import cpu_count as os_cpu_count
43+ import warnings
4244
43- _max_threads_count = mkl .get_max_threads ()
45+ class _cpu_max_threads_count :
46+ def __init__ (self ):
47+ self .cpu_count = None
48+ self .max_threads_count = None
4449
50+ def get_cpu_count (self ):
51+ max_threads = self .get_max_threads_count ()
52+ if self .cpu_count is None :
53+ self .cpu_count = os_cpu_count ()
54+ if self .cpu_count > max_threads :
55+ warnings .warn (
56+ ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
57+ "Using negative values of worker option may amount to requesting more threads than "
58+ "Intel(R) MKL can acommodate."
59+ ).format (self .cpu_count , max_threads ))
60+ return self .cpu_count
61+
62+ def get_max_threads_count (self ):
63+ if self .max_threads_count is None :
64+ self .max_threads_count = mkl .get_max_threads ()
65+
66+ return self .max_threads_count
67+
68+
69+ _hardware_counts = _cpu_max_threads_count ()
70+
4571
4672__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
4773 'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
@@ -113,10 +139,10 @@ def _workers_to_num_threads(w):
113139 if (_w == 0 ):
114140 raise ValueError ("Number of workers must be nonzero" )
115141 if (_w < 0 ):
116- _w += _max_threads_count + 1
142+ _w += _hardware_counts . get_cpu_count () + 1
117143 if _w <= 0 :
118144 raise ValueError ("workers value out of range; got {}, must not be"
119- " less than {}" .format (w , - _max_threads_count ))
145+ " less than {}" .format (w , - _hardware_counts . get_cpu_count () ))
120146 return _w
121147
122148
@@ -133,7 +159,8 @@ def __enter__(self):
133159
134160 def __exit__ (self , * args ):
135161 # restore default
136- mkl .domain_set_num_threads (_max_threads_count , domain = 'fft' )
162+ n_threads = _hardware_counts .get_max_threads_count ()
163+ mkl .domain_set_num_threads (n_threads , domain = 'fft' )
137164
138165
139166@_implements (_fft .fft )
0 commit comments