@@ -427,7 +427,7 @@ def _istft(
427427
428428 expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
429429
430- is_onesided = onesided . val if onesided else fft_size != n_fft
430+ is_onesided = True if fft_size != n_fft . val else onesided and onesided . val
431431 cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
432432
433433 # create a window of centered 1s of the requested size
@@ -481,20 +481,18 @@ def _istft(
481481 window_envelope = _overlap_add (x = window_mtx , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
482482 real_result = mb .real_div (x = real_result , y = window_envelope , before_op = before_op )
483483 imag_result = mb .real_div (x = imag_result , y = window_envelope , before_op = before_op )
484-
485484 # We need to adapt last dimension
486485 if length is not None :
487486 if length .val > expected_output_signal_len :
487+ real_result = mb .pad (x = real_result , pad = (0 , length .val - expected_output_signal_len ), before_op = before_op )
488+ imag_result = mb .pad (x = imag_result , pad = (0 , length .val - expected_output_signal_len ), before_op = before_op )
489+ elif length .val < expected_output_signal_len :
488490 if channels :
489- right_pad = mb .fill (shape = (channels , length .val - expected_output_signal_len ), value = 0. , before_op = before_op )
491+ real_result = mb .slice_by_size (x = real_result , begin = [0 ,0 ], size = [- 1 , length .val ], before_op = before_op )
492+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ,0 ], size = [- 1 , length .val ], before_op = before_op )
490493 else :
491- right_pad = mb .fill (shape = (length .val - expected_output_signal_len ,), value = 0. , before_op = before_op )
492-
493- real_result = mb .stack (values = (real_result , right_pad ), axis = 1 , before_op = before_op )
494- imag_result = mb .stack (values = (imag_result , right_pad ), axis = 1 , before_op = before_op )
495- elif length .val < expected_output_signal_len :
496- real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length .val ], before_op = before_op )
497- imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length .val ], before_op = before_op )
494+ real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length .val ], before_op = before_op )
495+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length .val ], before_op = before_op )
498496
499497 return real_result , imag_result
500498
0 commit comments