@@ -100,10 +100,34 @@ def _tot_size(x, axes):
100100 return prod ([s [ai ] for ai in axes ])
101101
102102
103+ def _workers_to_num_threads (w ):
104+ if w is None :
105+ return mkl .domain_get_max_threads (domain = 'fft' )
106+ return int (w )
107+
108+
109+ class Workers :
110+ def __init__ (self , workers ):
111+ self .workers = workers
112+ self .n_threads = _workers_to_num_threads (workers )
113+
114+ def __enter__ (self ):
115+ try :
116+ mkl .domain_set_num_threads (self .n_threads , domain = 'fft' )
117+ except :
118+ raise ValueError ("Class argument {} result in invalid number of threads {}" .format (self .workers , self .n_threads ))
119+
120+ def __exit__ (self , * args ):
121+ # restore default
122+ max_num_threads = mkl .domain_get_max_threads (domain = 'fft' )
123+ mkl .domain_set_num_threads (max_num_threads , domain = 'fft' )
124+
125+
103126@_implements (_fft .fft )
104127def fft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
105128 x = _float_utils .__upcast_float16_array (a )
106- output = _pydfti .fft (x , n = n , axis = axis , overwrite_x = overwrite_x )
129+ with Workers (workers ):
130+ output = _pydfti .fft (x , n = n , axis = axis , overwrite_x = overwrite_x )
107131 if _unitary (norm ):
108132 output *= 1 / sqrt (output .shape [axis ])
109133 return output
@@ -112,7 +136,8 @@ def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
112136@_implements (_fft .ifft )
113137def ifft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
114138 x = _float_utils .__upcast_float16_array (a )
115- output = _pydfti .ifft (x , n = n , axis = axis , overwrite_x = overwrite_x )
139+ with Workers (workers ):
140+ output = _pydfti .ifft (x , n = n , axis = axis , overwrite_x = overwrite_x )
116141 if _unitary (norm ):
117142 output *= sqrt (output .shape [axis ])
118143 return output
@@ -121,7 +146,8 @@ def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
121146@_implements (_fft .fft2 )
122147def fft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
123148 x = _float_utils .__upcast_float16_array (a )
124- output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
149+ with Workers (workers ):
150+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
125151 if _unitary (norm ):
126152 factor = 1
127153 for axis in axes :
@@ -133,7 +159,8 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
133159@_implements (_fft .ifft2 )
134160def ifft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
135161 x = _float_utils .__upcast_float16_array (a )
136- output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
162+ with Workers (workers ):
163+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
137164 if _unitary (norm ):
138165 factor = 1
139166 _axes = range (output .ndim ) if axes is None else axes
@@ -146,7 +173,8 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
146173@_implements (_fft .fftn )
147174def fftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
148175 x = _float_utils .__upcast_float16_array (a )
149- output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
176+ with Workers (workers ):
177+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
150178 if _unitary (norm ):
151179 factor = 1
152180 _axes = range (output .ndim ) if axes is None else axes
@@ -159,7 +187,8 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
159187@_implements (_fft .ifftn )
160188def ifftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
161189 x = _float_utils .__upcast_float16_array (a )
162- output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
190+ with Workers (workers ):
191+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
163192 if _unitary (norm ):
164193 factor = 1
165194 _axes = range (output .ndim ) if axes is None else axes
@@ -170,64 +199,67 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
170199
171200
172201@_implements (_fft .rfft )
173- def rfft (a , n = None , axis = - 1 , norm = None ):
202+ def rfft (a , n = None , axis = - 1 , norm = None , workers = None ):
174203 x = _float_utils .__upcast_float16_array (a )
175204 unitary = _unitary (norm )
176205 x = _float_utils .__downcast_float128_array (x )
177206 if unitary and n is None :
178207 x = asarray (x )
179208 n = x .shape [axis ]
180- output = _pydfti .rfft_numpy (x , n = n , axis = axis )
209+ with Workers (workers ):
210+ output = _pydfti .rfft_numpy (x , n = n , axis = axis )
181211 if unitary :
182212 output *= 1 / sqrt (n )
183213 return output
184214
185215
186216@_implements (_fft .irfft )
187- def irfft (a , n = None , axis = - 1 , norm = None ):
217+ def irfft (a , n = None , axis = - 1 , norm = None , workers = None ):
188218 x = _float_utils .__upcast_float16_array (a )
189219 x = _float_utils .__downcast_float128_array (x )
190- output = _pydfti .irfft_numpy (x , n = n , axis = axis )
220+ with Workers (workers ):
221+ output = _pydfti .irfft_numpy (x , n = n , axis = axis )
191222 if _unitary (norm ):
192223 output *= sqrt (output .shape [axis ])
193224 return output
194225
195226
196227@_implements (_fft .rfft2 )
197- def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
228+ def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None , workers = None ):
198229 x = _float_utils .__upcast_float16_array (a )
199230 x = _float_utils .__downcast_float128_array (a )
200- return rfftn (x , s , axes , norm )
231+ return rfftn (x , s , axes , norm , workers )
201232
202233
203234@_implements (_fft .irfft2 )
204- def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
235+ def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None , workers = None ):
205236 x = _float_utils .__upcast_float16_array (a )
206237 x = _float_utils .__downcast_float128_array (x )
207- return irfftn (x , s , axes , norm )
238+ return irfftn (x , s , axes , norm , workers )
208239
209240
210241@_implements (_fft .rfftn )
211- def rfftn (a , s = None , axes = None , norm = None ):
242+ def rfftn (a , s = None , axes = None , norm = None , workers = None ):
212243 unitary = _unitary (norm )
213244 x = _float_utils .__upcast_float16_array (a )
214245 x = _float_utils .__downcast_float128_array (x )
215246 if unitary :
216247 x = asarray (x )
217248 s , axes = _cook_nd_args (x , s , axes )
218-
219- output = _pydfti .rfftn_numpy (x , s , axes )
249+ with Workers ( workers ):
250+ output = _pydfti .rfftn_numpy (x , s , axes )
220251 if unitary :
221252 n_tot = prod (asarray (s , dtype = output .dtype ))
222253 output *= 1 / sqrt (n_tot )
223254 return output
224255
225256
226257@_implements (_fft .irfftn )
227- def irfftn (a , s = None , axes = None , norm = None ):
258+ def irfftn (a , s = None , axes = None , norm = None , workers = None ):
228259 x = _float_utils .__upcast_float16_array (a )
229260 x = _float_utils .__downcast_float128_array (x )
230- output = _pydfti .irfftn_numpy (x , s , axes )
261+ with Workers (workers ):
262+ output = _pydfti .irfftn_numpy (x , s , axes )
231263 if _unitary (norm ):
232264 output *= sqrt (_tot_size (output , axes ))
233265 return output
0 commit comments