@@ -56,7 +56,7 @@ cdef void _capsule_destructor(object caps):
5656 PyMem_Free (_cache )
5757 if (status != 0 ):
5858 raise ValueError ("Internal Error: Freeing DFTI Cache returned with error = {}" .format (status ))
59-
59+
6060
6161def _tls_dfti_cache_capsule ():
6262 cdef DftiCache * _cache_struct
@@ -72,7 +72,7 @@ def _tls_dfti_cache_capsule():
7272 capsule = getattr (_tls , 'capsule' , None )
7373 if (not cpython .pycapsule .PyCapsule_IsValid (capsule , capsule_name )):
7474 raise ValueError ("Internal Error: invalid capsule stored in TLS" )
75- return capsule
75+ return capsule
7676
7777
7878cdef extern from "Python.h" :
@@ -113,11 +113,6 @@ cdef extern from "src/mklfft.h":
113113 int float_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
114114 int float_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
115115
116- int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
117- int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
118- int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
119- int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
120-
121116 int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
122117 int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
123118
@@ -408,101 +403,239 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
408403
def rfft(x, n=None, axis=-1, overwrite_x=False):
    """Packed real-valued harmonics of FFT of a real sequence x.

    Thin wrapper: forwards ``x``, ``n`` and ``axis`` unchanged to
    ``_rr_fft1d_impl2`` and passes ``overwrite_x`` as ``overwrite_arg``.
    """
    return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x)
412407
413408
def irfft(x, n=None, axis=-1, overwrite_x=False):
    """Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT.

    Thin wrapper: forwards ``x``, ``n`` and ``axis`` unchanged to
    ``_rr_ifft1d_impl2`` and passes ``overwrite_x`` as ``overwrite_arg``.
    """
    return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x)
412+
413+
cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type):
    """Repack half-complex harmonics into the packed real layout.

    ``rc_arr`` holds complex harmonics ``c0, c1, c2, ...`` along ``axis``.
    The result has ``n`` real entries along ``axis`` laid out as
    ``[Re c0, Re c1, Im c1, Re c2, Im c2, ...]``.

    ``xnd`` is the number of dimensions of ``rc_arr``.  ``x_type`` selects
    the real dtype of the result: ``np.single`` when it equals
    ``cnp.NPY_FLOAT``, ``np.double`` otherwise.

    When ``rc_arr`` is C-contiguous and ``axis`` is its last axis, the
    repacking is done inside ``rc_arr``'s own buffer (its contents are
    modified) and a real-valued view of it is returned.  Otherwise a new
    array is allocated and filled.
    """
    cdef object res

    inp = <object> rc_arr
    real_dtype = np.single if (x_type == cnp.NPY_FLOAT) else np.double

    slice_ = [slice(None, None, None)] * xnd
    sl_0 = list(slice_)
    sl_0[axis] = 0

    # Viewing complex data as real doubles the length of the *last* axis, so
    # the in-place trick is only valid when ``axis`` is that axis.  The
    # stride test alone is not sufficient: with trailing length-1 dimensions
    # a non-last axis can also have stride == itemsize.
    if (inp.flags['C']
            and (axis == xnd - 1 or axis == -1)
            and inp.strides[axis] == inp.itemsize):
        sl_1 = list(slice_)
        sl_1[axis] = 1

        res = inp.view(dtype=real_dtype)
        # As reals the data reads [Re c0, Im c0, Re c1, Im c1, ...];
        # overwriting slot 1 with Re c0 makes slots 1..n the packed result.
        res[tuple(sl_1)] = res[tuple(sl_0)]

        slice_[axis] = slice(1, n + 1, None)
        return res[tuple(slice_)]

    res_shape = list(inp.shape)
    res_shape[axis] = n
    res = np.empty(tuple(res_shape), dtype=real_dtype)

    # Zero-frequency term has only a real part.
    res[tuple(sl_0)] = inp[tuple(sl_0)].real

    # Real parts of c1, c2, ... go to the odd positions 1, 3, 5, ...
    sl_dst_real = list(slice_)
    sl_dst_real[axis] = slice(1, None, 2)
    sl_src_real = list(slice_)
    sl_src_real[axis] = slice(1, None, None)
    res[tuple(sl_dst_real)] = inp[tuple(sl_src_real)].real

    # Imaginary parts go to the even positions 2, 4, ...; for even n the
    # last harmonic contributes no imaginary entry, so it is skipped.
    sl_dst_imag = list(slice_)
    sl_dst_imag[axis] = slice(2, None, 2)
    sl_src_imag = list(slice_)
    sl_src_imag[axis] = slice(
        1, inp.shape[axis] if (n & 1) else inp.shape[axis] - 1, None)
    res[tuple(sl_dst_imag)] = inp[tuple(sl_src_imag)].imag

    return res
452+
cdef object _rr_to_rc(cnp.ndarray rr_arr, int n, int axis, int xnd, int x_type):
    """Unpack the packed real layout into half-complex harmonics.

    ``rr_arr`` holds ``[r0, Re c1, Im c1, Re c2, Im c2, ...]`` along
    ``axis``; only its first ``n`` entries along that axis are read.  The
    result is a newly allocated C-ordered complex array with ``n // 2 + 1``
    entries along ``axis``.

    ``xnd`` is the number of dimensions of ``rr_arr``.  ``x_type`` selects
    the complex dtype: ``np.cdouble`` when it equals ``cnp.NPY_DOUBLE``,
    ``np.csingle`` otherwise.
    """
    packed = <object> rr_arr

    if x_type == cnp.NPY_DOUBLE:
        complex_dtype = np.cdouble
    else:
        complex_dtype = np.csingle

    out_shape = list(packed.shape)
    out_shape[axis] = n // 2 + 1
    rc = np.empty(tuple(out_shape), dtype=complex_dtype, order='C')

    # Template index selecting everything; copies below override ``axis``.
    whole = [slice(None, None, None)] * xnd

    src_re = list(whole)
    src_re[axis] = slice(1, n, 2)
    src_im = list(whole)
    src_im[axis] = slice(2, n, 2)

    dst_re = list(whole)
    dst_re[axis] = slice(1, None, None)
    dst_im = list(whole)
    dst_im[axis] = slice(1, (n + 1) // 2, None)

    first = list(whole)
    first[axis] = 0
    first = tuple(first)

    re_view = rc.real
    im_view = rc.imag

    # Odd packed positions carry real parts, even ones imaginary parts.
    re_view[tuple(dst_re)] = packed[tuple(src_re)]
    im_view[tuple(dst_im)] = packed[tuple(src_im)]

    # Zero-frequency harmonic is purely real.
    re_view[first] = packed[first]
    im_view[first] = 0

    if (n & 1) == 0:
        # For even n the last harmonic has no stored imaginary part.
        last = list(whole)
        last[axis] = -1
        im_view[tuple(last)] = 0

    return rc
491+
492+
def _repack_rr_to_rc(x, n, axis):
    """Debugging utility: expose ``_rr_to_rc`` to Python callers.

    ``x`` is converted with ``np.asarray`` and repacked from the packed
    real layout to half-complex harmonics for transform length ``n`` along
    ``axis``.  The array's own type number is passed on as ``x_type``.
    """
    cdef cnp.ndarray x_arr
    cdef int n_ = n, axis_ = axis
    cdef int x_type

    x_arr = <cnp.ndarray> np.asarray(x)
    x_type = cnp.PyArray_TYPE(x_arr)
    # Pass the converted array, not the raw argument: the helper is typed
    # to take an ndarray, so a list or other sequence would be rejected.
    return _rr_to_rc(x_arr, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
502+
503+
def _repack_rc_to_rr(x, n, axis):
    """Debugging utility: expose ``_rc_to_rr`` to Python callers.

    ``x`` is converted with ``np.asarray`` and repacked from half-complex
    harmonics to the packed real layout of length ``n`` along ``axis``.
    The real type handed to the helper is double when the array is complex
    double, and single precision for every other type.

    NOTE(review): ``_rc_to_rr`` may repack inside the buffer it is given,
    so when ``x`` is already an ndarray its contents can be modified --
    confirm this is acceptable for callers of this utility.
    """
    cdef cnp.ndarray x_arr
    cdef int n_ = n, axis_ = axis
    cdef int c_type, x_type

    x_arr = <cnp.ndarray> np.asarray(x)
    c_type = cnp.PyArray_TYPE(x_arr)
    x_type = cnp.NPY_DOUBLE if c_type == cnp.NPY_CDOUBLE else cnp.NPY_FLOAT
    # Pass the converted array, not the raw argument: the helper is typed
    # to take an ndarray, so a list or other sequence would be rejected.
    return _rc_to_rr(x_arr, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
417514
418515
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
    """
    Uses MKL to perform real packed 1D FFT on the input array x along the given axis.

    This is done by computing the half-complex (positive-index harmonics
    only) real-to-complex transform into a freshly allocated array and
    repacking that result with ``_rc_to_rr``.
    Thus overwrite_arg is effectively discarded.

    Functionally equivalent to scipy.fftpack.rfft

    Raises TypeError if x is complex single/double or cannot be converted
    to a double precision real array, and ValueError if the MKL call
    reports a non-zero status.
    """
    cdef cnp.ndarray x_arr "x_arrayObject"
    cdef cnp.ndarray f_arr "f_arrayObject"
    # in_place and dir_ are only out-parameters of __process_arguments;
    # their values are not used because the output is always newly allocated.
    cdef int xnd, in_place, dir_
    cdef long n_, axis_
    cdef int HALF_HARMONICS = 0  # give only positive index harmonics
    cdef int x_type, status, f_type
    cdef char * c_error_msg = NULL
    cdef bytes py_error_msg
    cdef DftiCache *_cache

    x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(+1),
                                &axis_, &n_, &in_place, &xnd, &dir_, 1)

    x_type = cnp.PyArray_TYPE(x_arr)

    if x_type == cnp.NPY_CFLOAT or x_type == cnp.NPY_CDOUBLE:
        raise TypeError("1st argument must be a real sequence")
    if x_type != cnp.NPY_FLOAT and x_type != cnp.NPY_DOUBLE:
        # Any other input type is converted to double precision real.
        try:
            x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
                x_arr, cnp.NPY_DOUBLE, cnp.NPY_BEHAVED)
        except Exception:
            raise TypeError("1st argument must be a real sequence")
        x_type = cnp.PyArray_TYPE(x_arr)

    # Half-complex output has n // 2 + 1 harmonics along the transform axis.
    f_type = cnp.NPY_CFLOAT if x_type == cnp.NPY_FLOAT else cnp.NPY_CDOUBLE
    f_arr = __allocate_result(x_arr, n_ // 2 + 1, axis_, f_type)

    _cache_capsule = _tls_dfti_cache_capsule()
    _cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
    if x_type == cnp.NPY_DOUBLE:
        status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
    else:
        status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)

    if status:
        c_error_msg = mkl_dfti_error(status)
        py_error_msg = c_error_msg
        raise ValueError("Internal error occurred: {}".format(py_error_msg))

    # post-process and return
    return _rc_to_rr(f_arr, n_, axis_, xnd, x_type)
570+
571+
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
    """
    Uses MKL to perform inverse real packed 1D FFT on the input array x along the given axis.

    This is done by repacking the input into half-complex harmonics with
    ``_rr_to_rc`` and running the complex-to-real inverse transform into a
    freshly allocated array.
    Thus overwrite_arg is effectively discarded.

    Functionally equivalent to scipy.fftpack.irfft

    Raises ValueError if x cannot be converted to a double precision real
    array or if the MKL call reports a non-zero status.
    """
    cdef cnp.ndarray x_arr "x_arrayObject"
    cdef cnp.ndarray f_arr "f_arrayObject"
    # in_place and dir_ are only out-parameters of __process_arguments;
    # their values are not used because the output is always newly allocated.
    cdef int xnd, in_place, dir_
    cdef long n_, axis_
    cdef int x_type, rc_type, status
    cdef char * c_error_msg = NULL
    cdef bytes py_error_msg
    cdef DftiCache *_cache

    x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(-1),
                                &axis_, &n_, &in_place, &xnd, &dir_, 1)

    x_type = cnp.PyArray_TYPE(x_arr)

    if x_type != cnp.NPY_FLOAT and x_type != cnp.NPY_DOUBLE:
        # Input is neither single nor double precision real: convert it to
        # double precision real before repacking.
        # NOTE(review): the message below mentions complex input, but the
        # conversion target is a real dtype -- confirm the intended wording.
        try:
            x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
                x_arr, cnp.NPY_DOUBLE, cnp.NPY_BEHAVED)
        except Exception:
            raise ValueError("First argument should be a real or a complex sequence of single or double precision")
        x_type = cnp.PyArray_TYPE(x_arr)

    # need to convert this into complex array
    rc_obj = _rr_to_rc(x_arr, n_, axis_, xnd, x_type)
    rc_arr = <cnp.ndarray> rc_obj

    rc_type = cnp.NPY_CFLOAT if x_type == cnp.NPY_FLOAT else cnp.NPY_CDOUBLE

    # The result is always written to a new real array of length n_ along axis_.
    f_arr = __allocate_result(x_arr, n_, axis_, x_type)

    # call out-of-place FFT
    if rc_type != cnp.NPY_CFLOAT and rc_type != cnp.NPY_CDOUBLE:
        raise ValueError("Internal mkl_fft error occurred: Unrecognized rc_type")

    _cache_capsule = _tls_dfti_cache_capsule()
    _cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
    if rc_type == cnp.NPY_CFLOAT:
        status = cfloat_float_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, _cache)
    else:
        status = cdouble_double_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, _cache)

    if status:
        c_error_msg = mkl_dfti_error(status)
        py_error_msg = c_error_msg
        raise ValueError("Internal error occurred: {}".format(str(py_error_msg)))

    return f_arr
506639
507640# this routine is functionally equivalent to numpy.fft.rfft
508641def _rc_fft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False ):
@@ -582,13 +715,13 @@ cdef int _is_integral(object num):
582715 return _integral
583716
584717
585- # this routine is functionally equivalent to numpy.fft.rfft
718+ # this routine is functionally equivalent to numpy.fft.irfft
586719def _rc_ifft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False ):
587720 """
588721 Uses MKL to perform 1D FFT on the real input array x along the given axis,
589722 producing complex output, but giving only half of the harmonics.
590723
591- cf. numpy.fft.rfft
724+ cf. numpy.fft.irfft
592725 """
593726 cdef cnp .ndarray x_arr "x_arrayObject"
594727 cdef cnp .ndarray f_arr "f_arrayObject"
@@ -891,8 +1024,8 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1):
8911024 if _direct :
8921025 return _direct_fftnd (x , overwrite_arg = overwrite_x , direction = direction )
8931026 else :
894- return _iter_fftnd (x , s = shape , axes = axes ,
895- overwrite_arg = overwrite_x ,
1027+ return _iter_fftnd (x , s = shape , axes = axes ,
1028+ overwrite_arg = overwrite_x ,
8961029 function = fft if direction == 1 else ifft )
8971030
8981031
@@ -933,7 +1066,7 @@ def _remove_axis(s, axes, axis_to_remove):
9331066
9341067
9351068cdef cnp .ndarray _trim_array (cnp .ndarray arr , object s , object axes ):
936- """Forms a view into subarray of arr if any element of shape parameter s is
1069+ """Forms a view into subarray of arr if any element of shape parameter s is
9371070 smaller than the corresponding element of the shape of the input array arr,
9381071 otherwise returns the input array"""
9391072 arr_shape = (< object > arr ).shape
0 commit comments