@@ -594,6 +594,85 @@ def irfft_numpy(x, n=None, axis=-1):
594594
595595# ============================== ND ====================================== #
596596
597+ # copied from scipy.fftpack.helper
598+ def _init_nd_shape_and_axes (x , shape , axes ):
599+ """Handle shape and axes arguments for n-dimensional transforms.
600+ Returns the shape and axes in a standard form, taking into account negative
601+ values and checking for various potential errors.
602+ Parameters
603+ ----------
604+ x : array_like
605+ The input array.
606+ shape : int or array_like of ints or None
607+ The shape of the result. If both `shape` and `axes` (see below) are
608+ None, `shape` is ``x.shape``; if `shape` is None but `axes` is
609+ not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
610+ If `shape` is -1, the size of the corresponding dimension of `x` is
611+ used.
612+ axes : int or array_like of ints or None
613+ Axes along which the calculation is computed.
614+ The default is over all axes.
615+ Negative indices are automatically converted to their positive
616+ counterpart.
617+ Returns
618+ -------
619+ shape : array
620+ The shape of the result. It is a 1D integer array.
621+ axes : array
622+ The shape of the result. It is a 1D integer array.
623+ """
624+ x = np .asarray (x )
625+ noshape = shape is None
626+ noaxes = axes is None
627+
628+ if noaxes :
629+ axes = np .arange (x .ndim , dtype = np .intc )
630+ else :
631+ axes = np .atleast_1d (axes )
632+
633+ if axes .size == 0 :
634+ axes = axes .astype (np .intc )
635+
636+ if not axes .ndim == 1 :
637+ raise ValueError ("when given, axes values must be a scalar or vector" )
638+ if not np .issubdtype (axes .dtype , np .integer ):
639+ raise ValueError ("when given, axes values must be integers" )
640+
641+ axes = np .where (axes < 0 , axes + x .ndim , axes )
642+
643+ if axes .size != 0 and (axes .max () >= x .ndim or axes .min () < 0 ):
644+ raise ValueError ("axes exceeds dimensionality of input" )
645+ if axes .size != 0 and np .unique (axes ).shape != axes .shape :
646+ raise ValueError ("all axes must be unique" )
647+
648+ if not noshape :
649+ shape = np .atleast_1d (shape )
650+ elif np .isscalar (x ):
651+ shape = np .array ([], dtype = np .intc )
652+ elif noaxes :
653+ shape = np .array (x .shape , dtype = np .intc )
654+ else :
655+ shape = np .take (x .shape , axes )
656+
657+ if shape .size == 0 :
658+ shape = shape .astype (np .intc )
659+
660+ if shape .ndim != 1 :
661+ raise ValueError ("when given, shape values must be a scalar or vector" )
662+ if not np .issubdtype (shape .dtype , np .integer ):
663+ raise ValueError ("when given, shape values must be integers" )
664+ if axes .shape != shape .shape :
665+ raise ValueError ("when given, axes and shape arguments"
666+ " have to be of the same length" )
667+
668+ shape = np .where (shape == - 1 , np .array (x .shape )[axes ], shape )
669+
670+ if shape .size != 0 and (shape < 1 ).any ():
671+ raise ValueError (
672+ "invalid number of data points ({0}) specified" .format (shape ))
673+
674+ return shape , axes
675+
597676
598677def _cook_nd_args (a , s = None , axes = None , invreal = 0 ):
599678 if s is None :
@@ -621,7 +700,7 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
621700
622701def _iter_fftnd (a , s = None , axes = None , function = fft , overwrite_arg = False ):
623702 a = np .asarray (a )
624- s , axes = _cook_nd_args (a , s , axes )
703+ s , axes = _init_nd_shape_and_axes (a , s , axes )
625704 ovwr = overwrite_arg
626705 for ii in reversed (range (len (axes ))):
627706 a = function (a , n = s [ii ], axis = axes [ii ], overwrite_x = ovwr )
0 commit comments