@@ -34,18 +34,52 @@ except ImportError:
3434 from numpy .core ._multiarray_tests import internal_overlap
3535
3636from libc .string cimport memcpy
37+ cimport cpython .pycapsule
38+ from cpython .exc cimport (PyErr_Occurred , PyErr_Clear )
39+ from cpython .mem cimport (PyMem_Malloc , PyMem_Free )
3740
3841from threading import Lock
42+ from threading import local as threading_local
3943_lock = Lock ()
4044
45+ # thread-local storage
46+ _tls = threading_local ()
47+
48+ cdef const char * capsule_name = "dfti_cache"
49+
50+ cdef void _capsule_destructor (object caps ):
51+ cdef DftiCache * _cache = NULL
52+ cdef int status = 0
53+ if (caps is None ):
54+ print ("Nothing to destroy" )
55+ return
56+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (caps , capsule_name )
57+ status = _free_dfti_cache (_cache )
58+ PyMem_Free (_cache )
59+ if (status != 0 ):
60+ raise ValueError ("Internal Error: Freeing DFTI Cache returned with error = {}" .format (status ))
61+
62+
63+ def _tls_dfti_cache_capsule ():
64+ cdef DftiCache * _cache_struct
65+
66+ init = getattr (_tls , 'initialized' , None )
67+ if (init is None ):
68+ _cache_struct = < DftiCache * > PyMem_Malloc (sizeof (DftiCache ));
69+ # important to initialized
70+ _cache_struct .initialized = 0
71+ _cache_struct .hand = NULL
72+ _tls .initialized = True
73+ _tls .capsule = cpython .pycapsule .PyCapsule_New (< void * > _cache_struct , capsule_name , & _capsule_destructor )
74+ capsule = getattr (_tls , 'capsule' , None )
75+ if (not cpython .pycapsule .PyCapsule_IsValid (capsule , capsule_name )):
76+ raise ValueError ("Internal Error: invalid capsule stored in TLS" )
77+ return capsule
78+
79+
4180cdef extern from "Python.h" :
4281 ctypedef int size_t
4382
44- void * PyMem_Malloc (size_t n )
45- void PyMem_Free (void * buf )
46-
47- int PyErr_Occurred ()
48- void PyErr_Clear ()
4983 long PyInt_AsLong (object ob )
5084 int PyObject_HasAttrString (object , char * )
5185
@@ -58,32 +92,36 @@ cdef extern from *:
5892 object PyArray_BASE (cnp .ndarray )
5993
6094cdef extern from "src/mklfft.h" :
61- int cdouble_mkl_fft1d_in (cnp .ndarray , int , int )
62- int cfloat_mkl_fft1d_in (cnp .ndarray , int , int )
63- int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
64- int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
65- int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
66- int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
67-
68- int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int )
69- int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int )
70- int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
71- int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
72- int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
73- int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
74-
75- int double_mkl_rfft_in (cnp .ndarray , int , int )
76- int double_mkl_irfft_in (cnp .ndarray , int , int )
77- int float_mkl_rfft_in (cnp .ndarray , int , int )
78- int float_mkl_irfft_in (cnp .ndarray , int , int )
79-
80- int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
81- int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
82- int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
83- int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
84-
85- int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
86- int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
95+ cdef struct DftiCache :
96+ void * hand
97+ int initialized
98+ int _free_dfti_cache (DftiCache * )
99+ int cdouble_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
100+ int cfloat_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
101+ int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
102+ int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
103+ int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
104+ int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
105+
106+ int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
107+ int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
108+ int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
109+ int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarra , DftiCache * )
110+ int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
111+ int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
112+
113+ int double_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
114+ int double_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
115+ int float_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
116+ int float_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
117+
118+ int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
119+ int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
120+ int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
121+ int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
122+
123+ int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
124+ int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
87125
88126 int cdouble_cdouble_mkl_fftnd_in (cnp .ndarray )
89127 int cdouble_cdouble_mkl_ifftnd_in (cnp .ndarray )
@@ -268,6 +306,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
268306 cdef int ALL_HARMONICS = 1
269307 cdef char * c_error_msg = NULL
270308 cdef bytes py_error_msg
309+ cdef DftiCache * _cache
271310
272311 x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
273312 & axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
@@ -296,16 +335,18 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
296335
297336 if in_place :
298337 with _lock :
338+ _cache_capsule = _tls_dfti_cache_capsule ()
339+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
299340 if x_type is cnp .NPY_CDOUBLE :
300341 if dir_ < 0 :
301- status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
342+ status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
302343 else :
303- status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
344+ status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
304345 elif x_type is cnp .NPY_CFLOAT :
305346 if dir_ < 0 :
306- status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
347+ status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
307348 else :
308- status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
349+ status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
309350 else :
310351 status = 1
311352
@@ -328,36 +369,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
328369
329370 # call out-of-place FFT
330371 with _lock :
372+ _cache_capsule = _tls_dfti_cache_capsule ()
373+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
331374 if f_type is cnp .NPY_CDOUBLE :
332375 if x_type is cnp .NPY_DOUBLE :
333376 if dir_ < 0 :
334377 status = double_cdouble_mkl_ifft1d_out (
335- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
378+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
336379 else :
337380 status = double_cdouble_mkl_fft1d_out (
338- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
381+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
339382 elif x_type is cnp .NPY_CDOUBLE :
340383 if dir_ < 0 :
341384 status = cdouble_cdouble_mkl_ifft1d_out (
342- x_arr , n_ , < int > axis_ , f_arr )
385+ x_arr , n_ , < int > axis_ , f_arr , _cache )
343386 else :
344387 status = cdouble_cdouble_mkl_fft1d_out (
345- x_arr , n_ , < int > axis_ , f_arr )
388+ x_arr , n_ , < int > axis_ , f_arr , _cache )
346389 else :
347390 if x_type is cnp .NPY_FLOAT :
348391 if dir_ < 0 :
349392 status = float_cfloat_mkl_ifft1d_out (
350- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
393+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
351394 else :
352395 status = float_cfloat_mkl_fft1d_out (
353- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
396+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
354397 elif x_type is cnp .NPY_CFLOAT :
355398 if dir_ < 0 :
356399 status = cfloat_cfloat_mkl_ifft1d_out (
357- x_arr , n_ , < int > axis_ , f_arr )
400+ x_arr , n_ , < int > axis_ , f_arr , _cache )
358401 else :
359402 status = cfloat_cfloat_mkl_fft1d_out (
360- x_arr , n_ , < int > axis_ , f_arr )
403+ x_arr , n_ , < int > axis_ , f_arr , _cache )
361404
362405 if (status ):
363406 c_error_msg = mkl_dfti_error (status )
@@ -388,6 +431,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
388431 cdef int x_type , status
389432 cdef char * c_error_msg = NULL
390433 cdef bytes py_error_msg
434+ cdef DftiCache * _cache
391435
392436 x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
393437 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -414,16 +458,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
414458
415459 if in_place :
416460 with _lock :
461+ _cache_capsule = _tls_dfti_cache_capsule ()
462+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
417463 if x_type is cnp .NPY_DOUBLE :
418464 if dir_ < 0 :
419- status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ )
465+ status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
420466 else :
421- status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ )
467+ status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
422468 elif x_type is cnp .NPY_FLOAT :
423469 if dir_ < 0 :
424- status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ )
470+ status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
425471 else :
426- status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ )
472+ status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
427473 else :
428474 status = 1
429475
@@ -444,16 +490,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
444490
445491 # call out-of-place FFT
446492 with _lock :
493+ _cache_capsule = _tls_dfti_cache_capsule ()
494+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
447495 if x_type is cnp .NPY_DOUBLE :
448496 if dir_ < 0 :
449- status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
497+ status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
450498 else :
451- status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
499+ status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
452500 else :
453501 if dir_ < 0 :
454- status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
502+ status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
455503 else :
456- status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
504+ status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
457505
458506 if (status ):
459507 c_error_msg = mkl_dfti_error (status )
@@ -479,6 +527,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
479527 cdef int direction = 1 # dummy, only used for the sake of arg-processing
480528 cdef char * c_error_msg = NULL
481529 cdef bytes py_error_msg
530+ cdef DftiCache * _cache
482531
483532 x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
484533 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -510,10 +559,14 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
510559 # call out-of-place FFT
511560 if x_type is cnp .NPY_FLOAT :
512561 with _lock :
513- status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
562+ _cache_capsule = _tls_dfti_cache_capsule ()
563+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
564+ status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
514565 else :
515566 with _lock :
516- status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
567+ _cache_capsule = _tls_dfti_cache_capsule ()
568+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
569+ status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
517570
518571 if (status ):
519572 c_error_msg = mkl_dfti_error (status )
@@ -553,6 +606,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
553606 cdef int direction = 1 # dummy, only used for the sake of arg-processing
554607 cdef char * c_error_msg = NULL
555608 cdef bytes py_error_msg
609+ cdef DftiCache * _cache
556610
557611 int_n = _is_integral (n )
558612 # nn gives the number elements along axis of the input that we use
@@ -592,10 +646,14 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
592646 # call out-of-place FFT
593647 if x_type is cnp .NPY_CFLOAT :
594648 with _lock :
595- status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
649+ _cache_capsule = _tls_dfti_cache_capsule ()
650+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
651+ status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
596652 else :
597653 with _lock :
598- status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
654+ _cache_capsule = _tls_dfti_cache_capsule ()
655+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
656+ status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
599657
600658 if (status ):
601659 c_error_msg = mkl_dfti_error (status )
0 commit comments