@@ -258,23 +258,43 @@ def _iter_fftnd(
258258 axes = None ,
259259 out = None ,
260260 direction = + 1 ,
261- overwrite_x = False ,
262- scale_function = lambda n , ind : 1.0 ,
261+ scale_function = lambda ind : 1.0 ,
263262):
264263 a = np .asarray (a )
265264 s , axes = _init_nd_shape_and_axes (a , s , axes )
266- ovwr = overwrite_x
267- for ii in reversed (range (len (axes ))):
265+
266+ # Combine the two, but in reverse, to end with the first axis given.
267+ axes_and_s = list (zip (axes , s ))[::- 1 ]
268+ # We try to use in-place calculations where possible, which is
269+ # everywhere except when the size changes after the first FFT.
270+ size_changes = [axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n ]
271+
272+ # If there are any size changes, we cannot use out
273+ res = None if size_changes else out
274+ for ind , (axis , n ) in enumerate (axes_and_s ):
275+ if axis in size_changes :
276+ if axis == size_changes [- 1 ]:
277+ # Last size change, so any output should now be OK
278+ # (an error will be raised if not), and if no output is
279+ # required, we want a freshly allocated array of the right size.
280+ res = out
281+ elif res is not None and n < res .shape [axis ]:
282+ # For an intermediate step where we return fewer elements, we
283+ # can use a smaller view of the previous array.
284+ res = res [(slice (None ),) * axis + (slice (n ),)]
285+ else :
286+ # If we need more elements, we cannot use res.
287+ res = None
268288 a = _c2c_fft1d_impl (
269289 a ,
270- n = s [ii ],
271- axis = axes [ii ],
272- overwrite_x = ovwr ,
290+ n = n ,
291+ axis = axis ,
273292 direction = direction ,
274- fsc = scale_function (s [ ii ], ii ),
275- out = out ,
293+ fsc = scale_function (ind ),
294+ out = res ,
276295 )
277- ovwr = True
296+ # Default output for next iteration.
297+ res = a
278298 return a
279299
280300
@@ -356,7 +376,6 @@ def _c2c_fftnd_impl(
356376 x ,
357377 s = None ,
358378 axes = None ,
359- overwrite_x = False ,
360379 direction = + 1 ,
361380 fsc = 1.0 ,
362381 out = None ,
@@ -381,7 +400,6 @@ def _c2c_fftnd_impl(
381400 if _direct :
382401 return _direct_fftnd (
383402 x ,
384- overwrite_x = overwrite_x ,
385403 direction = direction ,
386404 fsc = fsc ,
387405 out = out ,
@@ -399,11 +417,7 @@ def _c2c_fftnd_impl(
399417 x ,
400418 axes ,
401419 _direct_fftnd ,
402- {
403- "overwrite_x" : overwrite_x ,
404- "direction" : direction ,
405- "fsc" : fsc ,
406- },
420+ {"direction" : direction , "fsc" : fsc },
407421 res ,
408422 )
409423 else :
@@ -414,97 +428,122 @@ def _c2c_fftnd_impl(
414428 axes = axes ,
415429 out = out ,
416430 direction = direction ,
417- overwrite_x = overwrite_x ,
418- scale_function = lambda n , i : fsc if i == 0 else 1.0 ,
431+ scale_function = lambda i : fsc if i == 0 else 1.0 ,
419432 )
420433
421434
422435def _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
423436 a = np .asarray (x )
424437 no_trim = (s is None ) and (axes is None )
425438 s , axes = _cook_nd_args (a , s , axes )
439+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
426440 la = axes [- 1 ]
441+
427442 # trim array, so that rfft avoids doing unnecessary computations
428443 if not no_trim :
429444 a = _trim_array (a , s , axes )
445+
446+ # last axis is not included since we calculate r2c FFT separately
447+ # and not in the loop
448+ axes_and_s = list (zip (axes , s ))[- 2 ::- 1 ]
449+ size_changes = [axis for axis , n in axes_and_s if a .shape [axis ] != n ]
450+ res = None if size_changes else out
451+
430452 # r2c along last axis
431- a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
453+ a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
454+ res = a
432455 if len (s ) > 1 :
433- if not no_trim :
434- ss = list (s )
435- ss [- 1 ] = a .shape [la ]
436- a = _pad_array (a , tuple (ss ), axes )
456+
437457 len_axes = len (axes )
438458 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
459+ if not no_trim :
460+ ss = list (s )
461+ ss [- 1 ] = a .shape [la ]
462+ a = _pad_array (a , tuple (ss ), axes )
439463 # a series of ND c2c FFTs along last axis
440464 ss , aa = _remove_axis (s , axes , - 1 )
441- ind = [
442- slice (None , None , 1 ),
443- ] * len (s )
465+ ind = [slice (None , None , 1 )] * len (s )
444466 for ii in range (a .shape [la ]):
445467 ind [la ] = ii
446468 tind = tuple (ind )
447469 a_inp = a [tind ]
448- res = out [tind ] if out is not None else None
449- a_res = _c2c_fftnd_impl (
450- a_inp , s = ss , axes = aa , overwrite_x = True , direction = 1 , out = res
451- )
452- if a_res is not a_inp :
453- a [tind ] = a_res # copy in place
470+ res = out [tind ] if out is not None else a_inp
471+ _ = _c2c_fftnd_impl (a_inp , s = ss , axes = aa , direction = 1 , out = res )
472+ if out is not None :
473+ a = out
454474 else :
475+ # another size_changes check is needed if there are repeated axes
476+ # of last axis, since since FFT changes the shape along last axis
477+ size_changes = [
478+ axis for axis , n in axes_and_s if a .shape [axis ] != n
479+ ]
480+
455481 # a series of 1D c2c FFTs along all axes except last
456- for ii in range (len (axes ) - 2 , - 1 , - 1 ):
457- a = _c2c_fft1d_impl (a , s [ii ], axes [ii ], overwrite_x = True )
482+ for axis , n in axes_and_s :
483+ if axis in size_changes :
484+ if axis == size_changes [- 1 ]:
485+ res = out
486+ elif res is not None and n < res .shape [axis ]:
487+ res = res [(slice (None ),) * axis + (slice (n ),)]
488+ else :
489+ res = None
490+ a = _c2c_fft1d_impl (a , n , axis , out = res )
491+ res = a
458492 return a
459493
460494
461495def _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
462496 a = np .asarray (x )
463497 no_trim = (s is None ) and (axes is None )
464498 s , axes = _cook_nd_args (a , s , axes , invreal = True )
499+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
465500 la = axes [- 1 ]
466501 if not no_trim :
467502 a = _trim_array (a , s , axes )
468503 if len (s ) > 1 :
469- if not no_trim :
470- a = _pad_array (a , s , axes )
471- ovr_x = True if _datacopied (a , x ) else False
472504 len_axes = len (axes )
473505 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
506+ if not no_trim :
507+ a = _pad_array (a , s , axes )
474508 # a series of ND c2c FFTs along last axis
475509 # due to need to write into a, we must copy
476- if not ovr_x :
477- a = a .copy ()
478- ovr_x = True
510+ a = a if _datacopied (a , x ) else a .copy ()
479511 if not np .issubdtype (a .dtype , np .complexfloating ):
480512 # complex output will be copied to input, copy is needed
481513 if a .dtype == np .float32 :
482514 a = a .astype (np .complex64 )
483515 else :
484516 a = a .astype (np .complex128 )
485- ovr_x = True
486517 ss , aa = _remove_axis (s , axes , - 1 )
487- ind = [
488- slice (None , None , 1 ),
489- ] * len (s )
518+ ind = [slice (None , None , 1 )] * len (s )
490519 for ii in range (a .shape [la ]):
491520 ind [la ] = ii
492521 tind = tuple (ind )
493522 a_inp = a [tind ]
494523 # out has real dtype and cannot be used in intermediate steps
495- a_res = _c2c_fftnd_impl (
496- a_inp , s = ss , axes = aa , overwrite_x = True , direction = - 1
524+ # ss and aa are reversed since np.irfftn uses forward order but
525+ # np.ifftn uses reverse order see numpy-gh-28950
526+ _ = _c2c_fftnd_impl (
527+ a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1
497528 )
498- if a_res is not a_inp :
499- a [tind ] = a_res # copy in place
500529 else :
501530 # a series of 1D c2c FFTs along all axes except last
502- for ii in range (len (axes ) - 1 ):
503- # out has real dtype and cannot be used in intermediate steps
504- a = _c2c_fft1d_impl (
505- a , s [ii ], axes [ii ], overwrite_x = ovr_x , direction = - 1
506- )
507- ovr_x = True
531+ # forward order, see numpy-gh-28950
532+ axes_and_s = list (zip (axes , s ))[:- 1 ]
533+ size_changes = [
534+ axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n
535+ ]
536+ # out has real dtype cannot be used for intermediate steps
537+ res = None
538+ for axis , n in axes_and_s :
539+ if axis in size_changes :
540+ if res is not None and n < res .shape [axis ]:
541+ # pylint: disable=unsubscriptable-object
542+ res = res [(slice (None ),) * axis + (slice (n ),)]
543+ else :
544+ res = None
545+ a = _c2c_fft1d_impl (a , n , axis , out = res , direction = - 1 )
546+ res = a
508547 # c2r along last axis
509548 a = _c2r_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
510549 return a
0 commit comments