@@ -912,6 +912,48 @@ def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_f
912912 return a
913913
914914
915+ def flat_to_multi (ind , shape ):
916+ nd = len (shape )
917+ m_ind = [- 1 ] * nd
918+ j = ind
919+ for i in range (nd ):
920+ si = shape [nd - 1 - i ]
921+ q = j // si
922+ r = j - si * q
923+ m_ind [nd - 1 - i ] = r
924+ j = q
925+ return m_ind
926+
927+
928+ def iter_complementary (x , axes , func , kwargs , result ):
929+ if axes is None :
930+ return func (x , ** kwargs )
931+ x_shape = x .shape
932+ nd = x .ndim
933+ r = list (range (nd ))
934+ sl = [slice (None , None , None )] * nd
935+ if not isinstance (axes , tuple ):
936+ axes = (axes ,)
937+ for ai in axes :
938+ r [ai ] = None
939+ size = 1
940+ sub_shape = []
941+ dual_ind = []
942+ for ri in r :
943+ if ri is not None :
944+ size * = x_shape [ri ]
945+ sub_shape .append (x_shape [ri ])
946+ dual_ind .append (ri )
947+
948+ for ind in range (size ):
949+ m_ind = flat_to_multi (ind , sub_shape )
950+ for k1 , k2 in zip (dual_ind , m_ind ):
951+ sl [k1 ] = k2
952+ np .copyto (result [tuple (sl )], func (x [tuple (sl )], ** kwargs ))
953+
954+ return result
955+
956+
915957def _direct_fftnd (x , overwrite_arg = False , direction = + 1 , double fsc = 1.0 ):
916958 """Perform n-dimensional FFT over all axes"""
917959 cdef int err
@@ -988,6 +1030,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
9881030
9891031 return f_arr
9901032
1033+
9911034def _check_shapes_for_direct (xs , shape , axes ):
9921035 if len (axes ) > 7 : # Intel MKL supports up to 7D
9931036 return False
@@ -1006,6 +1049,14 @@ def _check_shapes_for_direct(xs, shape, axes):
10061049 return True
10071050
10081051
1052+ def _output_dtype (dt ):
1053+ if dt == np .double :
1054+ return np .cdouble
1055+ if dt == np .single :
1056+ return np .csingle
1057+ return dt
1058+
1059+
10091060def _fftnd_impl (x , shape = None , axes = None , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
10101061 if direction not in [- 1 , + 1 ]:
10111062 raise ValueError ("Direction of FFT should +1 or -1" )
@@ -1026,10 +1077,20 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
10261077 if _direct :
10271078 return _direct_fftnd (x , overwrite_arg = overwrite_x , direction = direction , fsc = fsc )
10281079 else :
1029- sc = (< object > fsc )** (1 / x .ndim )
1030- return _iter_fftnd (x , s = shape , axes = axes ,
1031- overwrite_arg = overwrite_x , scale_function = lambda n : sc ,
1032- function = fft if direction == 1 else ifft )
1080+ if (shape is None ):
1081+ x = np .asarray (x )
1082+ res = np .empty (x .shape , dtype = _output_dtype (x .dtype ))
1083+ return iter_complementary (
1084+ x , axes ,
1085+ _direct_fftnd ,
1086+ {'overwrite_arg' : overwrite_x , 'direction' : direction , 'fsc' : fsc },
1087+ res
1088+ )
1089+ else :
1090+ sc = (< object > fsc )** (1 / x .ndim )
1091+ return _iter_fftnd (x , s = shape , axes = axes ,
1092+ overwrite_arg = overwrite_x , scale_function = lambda n : sc ,
1093+ function = fft if direction == 1 else ifft )
10331094
10341095
10351096def fft2 (x , shape = None , axes = (- 2 ,- 1 ), overwrite_x = False , forward_scale = 1.0 ):
0 commit comments