@@ -338,21 +338,22 @@ def irfft(a, n=None, axis=-1, norm=None, workers=None, plan=None):
338338 return NotImplemented
339339 if x is NotImplemented :
340340 return x
341- fsc = _compute_1d_forward_scale (norm , n , x .shape [axis ])
341+ nn = n if n else 2 * (x .shape [axis ]- 1 )
342+ fsc = _compute_1d_forward_scale (norm , nn , x .shape [axis ])
342343 _check_plan (plan )
343344 with Workers (workers ):
344345 output = _pydfti .irfft_numpy (x , n = n , axis = axis , forward_scale = fsc )
345346 return output
346347
347348
348- def _compute_nd_forward_scale_for_rfft (norm , s , axes , x ):
349+ def _compute_nd_forward_scale_for_rfft (norm , s , axes , x , invreal = False ):
349350 if norm in (None , "backward" ):
350351 fsc = 1.0
351352 elif norm == "forward" :
352- s , axes = _cook_nd_args (x , s , axes )
353+ s , axes = _cook_nd_args (x , s , axes , invreal = invreal )
353354 fsc = _frwd_sc_nd (s , axes , x .shape )
354355 elif norm == "ortho" :
355- s , axes = _cook_nd_args (x , s , axes )
356+ s , axes = _cook_nd_args (x , s , axes , invreal = invreal )
356357 fsc = sqrt (_frwd_sc_nd (s , axes , x .shape ))
357358 else :
358359 _check_norm (norm )
@@ -380,7 +381,7 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None):
380381 return NotImplemented
381382 if x is NotImplemented :
382383 return x
383- s , axes , fsc = _compute_nd_forward_scale_for_rfft (norm , s , axes , x )
384+ s , axes , fsc = _compute_nd_forward_scale_for_rfft (norm , s , axes , x , invreal = True )
384385 _check_plan (plan )
385386 with Workers (workers ):
386387 output = _pydfti .irfftn_numpy (x , s , axes , forward_scale = fsc )
@@ -408,7 +409,7 @@ def irfftn(a, s=None, axes=None, norm=None, workers=None, plan=None):
408409 return NotImplemented
409410 if x is NotImplemented :
410411 return x
411- s , axes , fsc = _compute_nd_forward_scale_for_rfft (norm , s , axes , x )
412+ s , axes , fsc = _compute_nd_forward_scale_for_rfft (norm , s , axes , x , invreal = True )
412413 _check_plan (plan )
413414 with Workers (workers ):
414415 output = _pydfti .irfftn_numpy (x , s , axes , forward_scale = fsc )
0 commit comments