@@ -157,11 +157,11 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
157157
158158
159159def fft (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
160- return _fft1d_impl (x , n = n , axis = axis , overwrite_arg = overwrite_x , direction = + 1 , fsc = fwd_scale )
160+ return _fft1d_impl (x , n = n , axis = axis , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
161161
162162
163163def ifft (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
164- return _fft1d_impl (x , n = n , axis = axis , overwrite_arg = overwrite_x , direction = - 1 , fsc = fwd_scale )
164+ return _fft1d_impl (x , n = n , axis = axis , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
165165
166166
167167cdef cnp .ndarray pad_array (cnp .ndarray x_arr , cnp .npy_intp n , int axis , int realQ ):
@@ -200,7 +200,7 @@ cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int real
200200
201201
202202cdef cnp .ndarray __process_arguments (object x , object n , object axis ,
203- object overwrite_arg , object direction ,
203+ object overwrite_x , object direction ,
204204 long * axis_ , long * n_ , int * in_place ,
205205 int * xnd , int * dir_ , int realQ ):
206206 "Internal utility to validate and process input arguments of 1D FFT functions"
@@ -213,7 +213,7 @@ cdef cnp.ndarray __process_arguments(object x, object n, object axis,
213213 else :
214214 dir_ [0 ] = - 1 if direction is - 1 else + 1
215215
216- in_place [0 ] = 1 if overwrite_arg is True else 0
216+ in_place [0 ] = 1 if overwrite_x else 0
217217
218218 # convert x to ndarray, ensure that strides are multiples of itemsize
219219 x_arr = PyArray_CheckFromAny (
@@ -294,7 +294,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f
294294# Float/double inputs are not cast to complex, but are effectively
295295# treated as complexes with zero imaginary parts.
296296# All other types are cast to complex double.
297- def _fft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False , direction = + 1 , double fsc = 1.0 ):
297+ def _fft1d_impl (x , n = None , axis = - 1 , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
298298 """
299299 Uses MKL to perform 1D FFT on the input array x along the given axis.
300300 """
@@ -308,7 +308,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
308308 cdef bytes py_error_msg
309309 cdef DftiCache * _cache
310310
311- x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
311+ x_arr = __process_arguments (x , n , axis , overwrite_x , direction ,
312312 & axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
313313
314314 x_type = cnp .PyArray_TYPE (x_arr )
@@ -410,12 +410,12 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
410410
411411def rfftpack (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
412412 """Packed real-valued harmonics of FFT of a real sequence x"""
413- return _rr_fft1d_impl2 (x , n = n , axis = axis , overwrite_arg = overwrite_x , fsc = fwd_scale )
413+ return _rr_fft1d_impl2 (x , n = n , axis = axis , overwrite_x = overwrite_x , fsc = fwd_scale )
414414
415415
416416def irfftpack (x , n = None , axis = - 1 , overwrite_x = False , fwd_scale = 1.0 ):
417417 """Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
418- return _rr_ifft1d_impl2 (x , n = n , axis = axis , overwrite_arg = overwrite_x , fsc = fwd_scale )
418+ return _rr_ifft1d_impl2 (x , n = n , axis = axis , overwrite_x = overwrite_x , fsc = fwd_scale )
419419
420420
421421cdef object _rc_to_rr (cnp .ndarray rc_arr , int n , int axis , int xnd , int x_type ):
@@ -520,12 +520,12 @@ def _repack_rc_to_rr(x, n, axis):
520520 return _rc_to_rr (x , n_ , axis_ , cnp .PyArray_NDIM (x_arr ), x_type )
521521
522522
523- def _rr_fft1d_impl2 (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
523+ def _rr_fft1d_impl2 (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
524524 """
525525 Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
526526
527527 This done by using rfft and post-processing the result.
528- Thus overwrite_arg is effectively discarded.
528+ Thus overwrite_x is effectively discarded.
529529
530530 Functionally equivalent to scipy.fftpack.rfft
531531 """
@@ -539,7 +539,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
539539 cdef bytes py_error_msg
540540 cdef DftiCache * _cache
541541
542- x_arr = __process_arguments (x , n , axis , overwrite_arg , < object > (+ 1 ),
542+ x_arr = __process_arguments (x , n , axis , overwrite_x , < object > (+ 1 ),
543543 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
544544
545545 x_type = cnp .PyArray_TYPE (x_arr )
@@ -576,12 +576,12 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
576576 return _rc_to_rr (f_arr , n_ , axis_ , xnd , x_type )
577577
578578
579- def _rr_ifft1d_impl2 (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
579+ def _rr_ifft1d_impl2 (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
580580 """
581581 Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
582582
583583 This done by using rfft and post-processing the result.
584- Thus overwrite_arg is effectively discarded.
584+ Thus overwrite_x is effectively discarded.
585585
586586 Functionally equivalent to scipy.fftpack.irfft
587587 """
@@ -595,7 +595,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
595595 cdef bytes py_error_msg
596596 cdef DftiCache * _cache
597597
598- x_arr = __process_arguments (x , n , axis , overwrite_arg , < object > (- 1 ),
598+ x_arr = __process_arguments (x , n , axis , overwrite_x , < object > (- 1 ),
599599 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
600600
601601 x_type = cnp .PyArray_TYPE (x_arr )
@@ -645,7 +645,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
645645
646646
647647# this routine is functionally equivalent to numpy.fft.rfft
648- def _rc_fft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
648+ def _rc_fft1d_impl (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
649649 """
650650 Uses MKL to perform 1D FFT on the real input array x along the given axis,
651651 producing complex output, but giving only half of the harmonics.
@@ -663,13 +663,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
663663 cdef bytes py_error_msg
664664 cdef DftiCache * _cache
665665
666- x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
666+ x_arr = __process_arguments (x , n , axis , overwrite_x , direction ,
667667 & axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
668668
669669 x_type = cnp .PyArray_TYPE (x_arr )
670670
671671 if x_type is cnp .NPY_CFLOAT or x_type is cnp .NPY_CDOUBLE or x_type is cnp .NPY_CLONGDOUBLE :
672- raise TypeError ("1st argument must be a real sequence 1 " )
672+ raise TypeError ("1st argument must be a real sequence. " )
673673 elif x_type is cnp .NPY_FLOAT or x_type is cnp .NPY_DOUBLE :
674674 pass
675675 else :
@@ -723,7 +723,7 @@ cdef int _is_integral(object num):
723723
724724
725725# this routine is functionally equivalent to numpy.fft.irfft
726- def _rc_ifft1d_impl (x , n = None , axis = - 1 , overwrite_arg = False , double fsc = 1.0 ):
726+ def _rc_ifft1d_impl (x , n = None , axis = - 1 , overwrite_x = False , double fsc = 1.0 ):
727727 """
728728 Uses MKL to perform 1D FFT on the real input array x along the given axis,
729729 producing complex output, but giving only half of the harmonics.
@@ -743,7 +743,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
743743 int_n = _is_integral (n )
744744 # nn gives the number elements along axis of the input that we use
745745 nn = (n // 2 + 1 ) if int_n and n > 0 else n
746- x_arr = __process_arguments (x , nn , axis , overwrite_arg , direction ,
746+ x_arr = __process_arguments (x , nn , axis , overwrite_x , direction ,
747747 & axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
748748 n_ = 2 * (n_ - 1 )
749749 if int_n and (n % 2 == 1 ):
@@ -907,10 +907,10 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
907907 return s , axes
908908
909909
910- def _iter_fftnd (a , s = None , axes = None , function = fft , overwrite_arg = False , scale_function = lambda n , ind : 1.0 ):
910+ def _iter_fftnd (a , s = None , axes = None , function = fft , overwrite_x = False , scale_function = lambda n , ind : 1.0 ):
911911 a = np .asarray (a )
912912 s , axes = _init_nd_shape_and_axes (a , s , axes )
913- ovwr = overwrite_arg
913+ ovwr = overwrite_x
914914 for ii in reversed (range (len (axes ))):
915915 a = function (a , n = s [ii ], axis = axes [ii ], overwrite_x = ovwr , fwd_scale = scale_function (s [ii ], ii ))
916916 ovwr = True
@@ -959,7 +959,7 @@ def iter_complementary(x, axes, func, kwargs, result):
959959 return result
960960
961961
962- def _direct_fftnd (x , overwrite_arg = False , direction = + 1 , double fsc = 1.0 ):
962+ def _direct_fftnd (x , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
963963 """Perform n-dimensional FFT over all axes"""
964964 cdef int err
965965 cdef long n_max = 0
@@ -972,7 +972,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
972972 else :
973973 dir_ = - 1 if direction is - 1 else + 1
974974
975- in_place = 1 if overwrite_arg is True else 0
975+ in_place = 1 if overwrite_x else 0
976976
977977 # convert x to ndarray, ensure that strides are multiples of itemsize
978978 x_arr = PyArray_CheckFromAny (
@@ -1069,56 +1069,56 @@ def _output_dtype(dt):
10691069 return dt
10701070
10711071
1072- def _fftnd_impl (x , shape = None , axes = None , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
1072+ def _fftnd_impl (x , s = None , axes = None , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
10731073 if direction not in [- 1 , + 1 ]:
10741074 raise ValueError ("Direction of FFT should +1 or -1" )
10751075
10761076 # _direct_fftnd requires complex type, and full-dimensional transform
10771077 if isinstance (x , np .ndarray ) and x .size != 0 and x .ndim > 1 :
1078- _direct = shape is None and axes is None
1078+ _direct = s is None and axes is None
10791079 if _direct :
10801080 _direct = x .ndim <= 7 # Intel MKL only supports FFT up to 7D
10811081 if not _direct :
1082- xs , xa = _cook_nd_args (x , shape , axes )
1082+ xs , xa = _cook_nd_args (x , s , axes )
10831083 if _check_shapes_for_direct (xs , x .shape , xa ):
10841084 _direct = True
10851085 _direct = _direct and x .dtype in [np .complex64 , np .complex128 , np .float32 , np .float64 ]
10861086 else :
10871087 _direct = False
10881088
10891089 if _direct :
1090- return _direct_fftnd (x , overwrite_arg = overwrite_x , direction = direction , fsc = fsc )
1090+ return _direct_fftnd (x , overwrite_x = overwrite_x , direction = direction , fsc = fsc )
10911091 else :
1092- if (shape is None and x .dtype in [np .csingle , np .cdouble , np .single , np .double ]):
1092+ if (s is None and x .dtype in [np .csingle , np .cdouble , np .single , np .double ]):
10931093 x = np .asarray (x )
10941094 res = np .empty (x .shape , dtype = _output_dtype (x .dtype ))
10951095 return iter_complementary (
10961096 x , axes ,
10971097 _direct_fftnd ,
1098- {'overwrite_arg ' : overwrite_x , 'direction' : direction , 'fsc' : fsc },
1098+ {'overwrite_x ' : overwrite_x , 'direction' : direction , 'fsc' : fsc },
10991099 res
11001100 )
11011101 else :
11021102 sc = < object > fsc
1103- return _iter_fftnd (x , s = shape , axes = axes ,
1104- overwrite_arg = overwrite_x , scale_function = lambda n , i : sc if i == 0 else 1. ,
1103+ return _iter_fftnd (x , s = s , axes = axes ,
1104+ overwrite_x = overwrite_x , scale_function = lambda n , i : sc if i == 0 else 1. ,
11051105 function = fft if direction == 1 else ifft )
11061106
11071107
1108- def fft2 (x , shape = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1109- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
1108+ def fft2 (x , s = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1109+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
11101110
11111111
1112- def ifft2 (x , shape = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1113- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
1112+ def ifft2 (x , s = None , axes = (- 2 ,- 1 ), overwrite_x = False , fwd_scale = 1.0 ):
1113+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
11141114
11151115
1116- def fftn (x , shape = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1117- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
1116+ def fftn (x , s = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1117+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = + 1 , fsc = fwd_scale )
11181118
11191119
1120- def ifftn (x , shape = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1121- return _fftnd_impl (x , shape = shape , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
1120+ def ifftn (x , s = None , axes = None , overwrite_x = False , fwd_scale = 1.0 ):
1121+ return _fftnd_impl (x , s = s , axes = axes , overwrite_x = overwrite_x , direction = - 1 , fsc = fwd_scale )
11221122
11231123
11241124def rfft2 (x , s = None , axes = (- 2 ,- 1 ), fwd_scale = 1.0 ):
@@ -1154,7 +1154,7 @@ cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes):
11541154 raise ValueError ("Invalid axis (%d) specified" % ai )
11551155 if si < shp_i :
11561156 if no_trim :
1157- ind = [slice (None ,None ,None ),] * len (s )
1157+ ind = [slice (None ,None ,None ),] * len (arr_shape )
11581158 no_trim = False
11591159 ind [ai ] = slice (None , si , None )
11601160 if no_trim :
@@ -1203,12 +1203,12 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0):
12031203 tind = tuple (ind )
12041204 a_inp = a [tind ]
12051205 a_res = _fftnd_impl (
1206- a_inp , shape = ss , axes = aa ,
1206+ a_inp , s = ss , axes = aa ,
12071207 overwrite_x = True , direction = 1 )
12081208 if a_res is not a_inp :
12091209 a [tind ] = a_res # copy in place
12101210 else :
1211- for ii in range (len (axes )- 1 ):
1211+ for ii in range (len (axes ) - 2 , - 1 , - 1 ):
12121212 a = fft (a , s [ii ], axes [ii ], overwrite_x = True )
12131213 return a
12141214
@@ -1218,6 +1218,8 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
12181218 no_trim = (s is None ) and (axes is None )
12191219 s , axes = _cook_nd_args (a , s , axes , invreal = True )
12201220 la = axes [- 1 ]
1221+ if not no_trim :
1222+ a = _trim_array (a , s , axes )
12211223 if len (s ) > 1 :
12221224 if not no_trim :
12231225 a = _fix_dimensions (a , s , axes )
@@ -1227,14 +1229,18 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
12271229 if not ovr_x :
12281230 a = a .copy ()
12291231 ovr_x = True
1232+ if not np .issubdtype (a .dtype , np .complexfloating ):
1233+ # copy is needed, because output of complex type will be copied to input
1234+ a = a .astype (np .complex64 ) if a .dtype == np .float32 else a .astype (np .complex128 )
1235+ ovr_x = True
12301236 ss , aa = _remove_axis (s , axes , - 1 )
1231- ind = [slice (None ,None ,1 ),] * len (s )
1237+ ind = [slice (None , None , 1 ),] * len (s )
12321238 for ii in range (a .shape [la ]):
12331239 ind [la ] = ii
12341240 tind = tuple (ind )
12351241 a_inp = a [tind ]
12361242 a_res = _fftnd_impl (
1237- a_inp , shape = ss , axes = aa ,
1243+ a_inp , s = ss , axes = aa ,
12381244 overwrite_x = True , direction = - 1 )
12391245 if a_res is not a_inp :
12401246 a [tind ] = a_res # copy in place
0 commit comments