@@ -34,18 +34,50 @@ except ImportError:
3434 from numpy .core ._multiarray_tests import internal_overlap
3535
3636from libc .string cimport memcpy
37+ cimport cpython .pycapsule
38+ from cpython .exc cimport (PyErr_Occurred , PyErr_Clear )
39+ from cpython .mem cimport (PyMem_Malloc , PyMem_Free )
40+
41+ from threading import local as threading_local
42+
43+ # thread-local storage
44+ _tls = threading_local ()
45+
46+ cdef const char * capsule_name = "dfti_cache"
47+
48+ cdef void _capsule_destructor (object caps ):
49+ cdef DftiCache * _cache = NULL
50+ cdef int status = 0
51+ if (caps is None ):
52+ print ("Nothing to destroy" )
53+ return
54+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (caps , capsule_name )
55+ status = _free_dfti_cache (_cache )
56+ PyMem_Free (_cache )
57+ if (status != 0 ):
58+ raise ValueError ("Internal Error: Freeing DFTI Cache returned with error = {}" .format (status ))
59+
60+
61+ def _tls_dfti_cache_capsule ():
62+ cdef DftiCache * _cache_struct
63+
64+ init = getattr (_tls , 'initialized' , None )
65+ if (init is None ):
66+ _cache_struct = < DftiCache * > PyMem_Malloc (sizeof (DftiCache ));
67+ # important to initialized
68+ _cache_struct .initialized = 0
69+ _cache_struct .hand = NULL
70+ _tls .initialized = True
71+ _tls .capsule = cpython .pycapsule .PyCapsule_New (< void * > _cache_struct , capsule_name , & _capsule_destructor )
72+ capsule = getattr (_tls , 'capsule' , None )
73+ if (not cpython .pycapsule .PyCapsule_IsValid (capsule , capsule_name )):
74+ raise ValueError ("Internal Error: invalid capsule stored in TLS" )
75+ return capsule
3776
38- from threading import Lock
39- _lock = Lock ()
4077
4178cdef extern from "Python.h" :
4279 ctypedef int size_t
4380
44- void * PyMem_Malloc (size_t n )
45- void PyMem_Free (void * buf )
46-
47- int PyErr_Occurred ()
48- void PyErr_Clear ()
4981 long PyInt_AsLong (object ob )
5082 int PyObject_HasAttrString (object , char * )
5183
@@ -58,32 +90,36 @@ cdef extern from *:
5890 object PyArray_BASE (cnp .ndarray )
5991
6092cdef extern from "src/mklfft.h" :
61- int cdouble_mkl_fft1d_in (cnp .ndarray , int , int )
62- int cfloat_mkl_fft1d_in (cnp .ndarray , int , int )
63- int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
64- int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
65- int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
66- int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
67-
68- int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int )
69- int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int )
70- int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
71- int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
72- int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
73- int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
74-
75- int double_mkl_rfft_in (cnp .ndarray , int , int )
76- int double_mkl_irfft_in (cnp .ndarray , int , int )
77- int float_mkl_rfft_in (cnp .ndarray , int , int )
78- int float_mkl_irfft_in (cnp .ndarray , int , int )
79-
80- int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
81- int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
82- int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
83- int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
84-
85- int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
86- int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
93+ cdef struct DftiCache :
94+ void * hand
95+ int initialized
96+ int _free_dfti_cache (DftiCache * )
97+ int cdouble_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
98+ int cfloat_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
99+ int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
100+ int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
101+ int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
102+ int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
103+
104+ int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
105+ int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
106+ int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
107+ int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarra , DftiCache * )
108+ int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
109+ int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
110+
111+ int double_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
112+ int double_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
113+ int float_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
114+ int float_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
115+
116+ int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
117+ int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
118+ int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
119+ int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
120+
121+ int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
122+ int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
87123
88124 int cdouble_cdouble_mkl_fftnd_in (cnp .ndarray )
89125 int cdouble_cdouble_mkl_ifftnd_in (cnp .ndarray )
@@ -268,6 +304,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
268304 cdef int ALL_HARMONICS = 1
269305 cdef char * c_error_msg = NULL
270306 cdef bytes py_error_msg
307+ cdef DftiCache * _cache
271308
272309 x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
273310 & axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
@@ -295,19 +332,20 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
295332 in_place = 1
296333
297334 if in_place :
298- with _lock :
299- if x_type is cnp .NPY_CDOUBLE :
300- if dir_ < 0 :
301- status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
302- else :
303- status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
304- elif x_type is cnp .NPY_CFLOAT :
305- if dir_ < 0 :
306- status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
307- else :
308- status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
335+ _cache_capsule = _tls_dfti_cache_capsule ()
336+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
337+ if x_type is cnp .NPY_CDOUBLE :
338+ if dir_ < 0 :
339+ status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
340+ else :
341+ status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
342+ elif x_type is cnp .NPY_CFLOAT :
343+ if dir_ < 0 :
344+ status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
309345 else :
310- status = 1
346+ status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
347+ else :
348+ status = 1
311349
312350 if status :
313351 c_error_msg = mkl_dfti_error (status )
@@ -327,37 +365,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
327365 f_arr = __allocate_result (x_arr , n_ , axis_ , f_type );
328366
329367 # call out-of-place FFT
330- with _lock :
331- if f_type is cnp .NPY_CDOUBLE :
332- if x_type is cnp .NPY_DOUBLE :
333- if dir_ < 0 :
334- status = double_cdouble_mkl_ifft1d_out (
335- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
336- else :
337- status = double_cdouble_mkl_fft1d_out (
338- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
339- elif x_type is cnp .NPY_CDOUBLE :
340- if dir_ < 0 :
341- status = cdouble_cdouble_mkl_ifft1d_out (
342- x_arr , n_ , < int > axis_ , f_arr )
343- else :
344- status = cdouble_cdouble_mkl_fft1d_out (
345- x_arr , n_ , < int > axis_ , f_arr )
346- else :
347- if x_type is cnp .NPY_FLOAT :
348- if dir_ < 0 :
349- status = float_cfloat_mkl_ifft1d_out (
350- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
351- else :
352- status = float_cfloat_mkl_fft1d_out (
353- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
354- elif x_type is cnp .NPY_CFLOAT :
355- if dir_ < 0 :
356- status = cfloat_cfloat_mkl_ifft1d_out (
357- x_arr , n_ , < int > axis_ , f_arr )
358- else :
359- status = cfloat_cfloat_mkl_fft1d_out (
360- x_arr , n_ , < int > axis_ , f_arr )
368+ _cache_capsule = _tls_dfti_cache_capsule ()
369+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
370+ if f_type is cnp .NPY_CDOUBLE :
371+ if x_type is cnp .NPY_DOUBLE :
372+ if dir_ < 0 :
373+ status = double_cdouble_mkl_ifft1d_out (
374+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
375+ else :
376+ status = double_cdouble_mkl_fft1d_out (
377+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
378+ elif x_type is cnp .NPY_CDOUBLE :
379+ if dir_ < 0 :
380+ status = cdouble_cdouble_mkl_ifft1d_out (
381+ x_arr , n_ , < int > axis_ , f_arr , _cache )
382+ else :
383+ status = cdouble_cdouble_mkl_fft1d_out (
384+ x_arr , n_ , < int > axis_ , f_arr , _cache )
385+ else :
386+ if x_type is cnp .NPY_FLOAT :
387+ if dir_ < 0 :
388+ status = float_cfloat_mkl_ifft1d_out (
389+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
390+ else :
391+ status = float_cfloat_mkl_fft1d_out (
392+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
393+ elif x_type is cnp .NPY_CFLOAT :
394+ if dir_ < 0 :
395+ status = cfloat_cfloat_mkl_ifft1d_out (
396+ x_arr , n_ , < int > axis_ , f_arr , _cache )
397+ else :
398+ status = cfloat_cfloat_mkl_fft1d_out (
399+ x_arr , n_ , < int > axis_ , f_arr , _cache )
361400
362401 if (status ):
363402 c_error_msg = mkl_dfti_error (status )
@@ -388,6 +427,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
388427 cdef int x_type , status
389428 cdef char * c_error_msg = NULL
390429 cdef bytes py_error_msg
430+ cdef DftiCache * _cache
391431
392432 x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
393433 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -413,19 +453,20 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
413453 in_place = 1
414454
415455 if in_place :
416- with _lock :
417- if x_type is cnp .NPY_DOUBLE :
418- if dir_ < 0 :
419- status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ )
420- else :
421- status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ )
422- elif x_type is cnp .NPY_FLOAT :
423- if dir_ < 0 :
424- status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ )
425- else :
426- status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ )
456+ _cache_capsule = _tls_dfti_cache_capsule ()
457+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
458+ if x_type is cnp .NPY_DOUBLE :
459+ if dir_ < 0 :
460+ status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
461+ else :
462+ status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
463+ elif x_type is cnp .NPY_FLOAT :
464+ if dir_ < 0 :
465+ status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
427466 else :
428- status = 1
467+ status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
468+ else :
469+ status = 1
429470
430471 if status :
431472 c_error_msg = mkl_dfti_error (status )
@@ -443,17 +484,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
443484 f_arr = __allocate_result (x_arr , n_ , axis_ , x_type );
444485
445486 # call out-of-place FFT
446- with _lock :
447- if x_type is cnp .NPY_DOUBLE :
448- if dir_ < 0 :
449- status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
450- else :
451- status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
487+ _cache_capsule = _tls_dfti_cache_capsule ()
488+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
489+ if x_type is cnp .NPY_DOUBLE :
490+ if dir_ < 0 :
491+ status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
452492 else :
453- if dir_ < 0 :
454- status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
455- else :
456- status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
493+ status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
494+ else :
495+ if dir_ < 0 :
496+ status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
497+ else :
498+ status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
457499
458500 if (status ):
459501 c_error_msg = mkl_dfti_error (status )
@@ -479,6 +521,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
479521 cdef int direction = 1 # dummy, only used for the sake of arg-processing
480522 cdef char * c_error_msg = NULL
481523 cdef bytes py_error_msg
524+ cdef DftiCache * _cache
482525
483526 x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
484527 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -509,11 +552,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
509552
510553 # call out-of-place FFT
511554 if x_type is cnp .NPY_FLOAT :
512- with _lock :
513- status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
555+ _cache_capsule = _tls_dfti_cache_capsule ()
556+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
557+ status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
514558 else :
515- with _lock :
516- status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
559+ _cache_capsule = _tls_dfti_cache_capsule ()
560+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
561+ status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
517562
518563 if (status ):
519564 c_error_msg = mkl_dfti_error (status )
@@ -553,6 +598,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
553598 cdef int direction = 1 # dummy, only used for the sake of arg-processing
554599 cdef char * c_error_msg = NULL
555600 cdef bytes py_error_msg
601+ cdef DftiCache * _cache
556602
557603 int_n = _is_integral (n )
558604 # nn gives the number elements along axis of the input that we use
@@ -591,11 +637,13 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
591637
592638 # call out-of-place FFT
593639 if x_type is cnp .NPY_CFLOAT :
594- with _lock :
595- status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
640+ _cache_capsule = _tls_dfti_cache_capsule ()
641+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
642+ status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
596643 else :
597- with _lock :
598- status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
644+ _cache_capsule = _tls_dfti_cache_capsule ()
645+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
646+ status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
599647
600648 if (status ):
601649 c_error_msg = mkl_dfti_error (status )
0 commit comments