@@ -454,16 +454,20 @@ def _istft(
454454 real_result = _overlap_add (x = real_result , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
455455 imag_result = _overlap_add (x = imag_result , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
456456
457+ # Normalize by the window square
458+ n_frames = mb .shape (x = real_result , before_op = before_op )[1 ]
459+ window_square = mb .mul (x = window , y = window , before_op = before_op )
460+ window_mtx = mb .stack (values = [window_square ] * n_frames , axis = 1 )
461+ normalization_factor = _overlap_add (x = window_mtx , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
462+
463+ real_result = mb .real_div (x = real_result , y = normalization_factor , before_op = before_op )
464+ imag_result = mb .real_div (x = imag_result , y = normalization_factor , before_op = before_op )
465+
457466 # reduce the rank of the output
458467 if should_increase_rank :
459468 real_result = mb .squeeze (x = real_result , axes = (0 ,), before_op = before_op )
460469 imag_result = mb .squeeze (x = imag_result , axes = (0 ,), before_op = before_op )
461470
462- if normalized and normalized .val :
463- divisor = mb .sqrt (x = mb .cast (x = n_fft , dtype = "fp32" , before_op = before_op ), before_op = before_op )
464- real_result = mb .real_div (x = real_result , y = divisor , before_op = before_op )
465- imag_result = mb .real_div (x = imag_result , y = divisor , before_op = before_op )
466-
467471 return real_result , imag_result
468472
469473def _overlap_add (
@@ -473,7 +477,7 @@ def _overlap_add(
473477 before_op : Operation ,
474478) -> Var :
475479 n_frames = mb .shape (x = x , before_op = before_op )[1 ]
476- output = mb .fill (shape = (n_fft + hop_length * (n_frames - 1 )), value = 0. , before_op = before_op )
480+ output = mb .fill (shape = (n_fft . val + hop_length . val * (n_frames - 1 )), value = 0. , before_op = before_op )
477481 signal_frames = mb .split (x = x , num_splits = n_frames , axis = 1 , before_op = before_op )
478482 local_idx = mb .range_1d (start = 0 , end = n_fft , step = 1 , before_op = before_op )
479483
0 commit comments