@@ -397,6 +397,7 @@ def _istft(
397397 hop_length : Optional [Var ],
398398 win_length : Optional [Var ],
399399 window : Optional [Var ],
400+ center : Optional [Var ],
400401 normalized : Optional [Var ],
401402 onesided : Optional [Var ],
402403 length : Optional [Var ],
@@ -435,12 +436,10 @@ def _istft(
435436 cos_base = mb .mul (x = window , y = cos_base , before_op = before_op )
436437 sin_base = mb .mul (x = window , y = sin_base , before_op = before_op )
437438
438- cos_base = mb .expand_dims (x = cos_base , axes = (1 ,), before_op = before_op )
439- sin_base = mb .expand_dims (x = sin_base , axes = (1 ,), before_op = before_op )
440439 hop_size = mb .expand_dims (x = hop_length , axes = (0 ,), before_op = before_op )
441440
442- signal_real = mb . expand_dims ( x = input_real , axes = ( 1 ,), before_op = before_op )
443- signal_imaginary = mb . expand_dims ( x = input_imaginary , axes = ( 1 ,), before_op = before_op )
441+ signal_real = input_real
442+ signal_imaginary = input_imaginary
444443
445444 # De-normalized signal before applying the IFT
446445 if normalized and normalized .val :
@@ -455,15 +454,16 @@ def _istft(
455454 # So using the definition in stft function, we get:
456455 # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
457456 # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458- cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459- sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460- cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
461- sin_windows_imag = mb .conv (x = signal_imaginary , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
457+ cos_windows_real = mb .matmul (x = signal_real , y = cos_base , transpose_x = True , before_op = before_op )
458+ sin_windows_real = mb .matmul (x = signal_real , y = sin_base , transpose_x = True , before_op = before_op )
459+ cos_windows_imag = mb .matmul (x = signal_imaginary , y = cos_base , transpose_x = True , before_op = before_op )
460+ sin_windows_imag = mb .matmul (x = signal_imaginary , y = sin_base , transpose_x = True , before_op = before_op )
462461
463462 real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
464463 imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
465464
466465 # Divide by N
466+ n_fft = mb .cast (x = n_fft , dtype = "fp32" , before_op = before_op )
467467 real_result = mb .real_div (x = real_result , y = n_fft , before_op = before_op )
468468 imag_result = mb .real_div (x = imag_result , y = n_fft , before_op = before_op )
469469
@@ -472,9 +472,9 @@ def _istft(
472472 imag_result = _overlap_add (x = imag_result , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
473473
474474 # Normalize by the window square
475- n_frames = mb .shape (x = real_result , before_op = before_op )[1 ]
476475 window_square = mb .mul (x = window , y = window , before_op = before_op )
477- window_mtx = mb .stack (values = [window_square ] * n_frames , axis = 1 )
476+ window_mtx = mb .stack (values = [window_square ] * n_frames , axis = 0 , before_op = before_op )
477+ window_mtx = mb .expand_dims (x = window_mtx , axes = (0 ,), before_op = before_op )
478478 window_envelope = _overlap_add (x = window_mtx , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
479479 real_result = mb .real_div (x = real_result , y = window_envelope , before_op = before_op )
480480 imag_result = mb .real_div (x = imag_result , y = window_envelope , before_op = before_op )
@@ -502,17 +502,27 @@ def _overlap_add(
502502 """
503503 input_shape = mb .shape (x = x , before_op = before_op )
504504 channels = input_shape .val [0 ]
505- n_frames = input_shape .val [2 ]
505+ n_frames = input_shape .val [1 ]
506+
507+ # Create empty output with final shape
508+ output = mb .fill (shape = (channels , int (n_fft .val + hop_length .val * (n_frames - 1 ))), value = 0. , before_op = before_op )
506509
507- output = mb . fill ( shape = ( channels , n_fft . val + hop_length . val * ( n_frames - 1 )), value = 0. , before_op = before_op )
508- signal_frames = mb .split (x = x , num_splits = n_frames , axis = 2 , before_op = before_op )
510+ # Create the frame-local index used later by the overlap-add
511+ n_fft = mb .cast (x = n_fft , dtype = "int32" , before_op = before_op )
509512 local_idx = mb .range_1d (start = 0 , end = n_fft , step = 1 , before_op = before_op )
510513
514+ # Split data into frames and iterate
515+ signal_frames = mb .split (x = x , num_splits = n_frames , axis = 1 , before_op = before_op )
516+
511517 for frame_num , frame in enumerate (signal_frames ):
518+ frame = mb .squeeze (x = frame , axes = [1 ], before_op = before_op )
519+
520+ # Create index to align data frames
512521 global_idx = mb .add (x = local_idx , y = frame_num * hop_length .val , before_op = before_op )
513- global_idx = mb .expand_dims (x = global_idx , axes = (0 ,), before_op = before_op )
514- global_idx = mb .stack (values = [global_idx ] * channels , axis = 0 )
515- output = mb .scatter_nd (data = output , indices = global_idx , updates = frame , before_op = before_op )
522+ global_idx = mb .stack (values = [global_idx ] * channels , axis = 0 , before_op = before_op )
523+
524+ # Add data frame
525+ output = mb .scatter_along_axis (data = output , indices = global_idx , updates = frame , axis = 1 , mode = "add" , before_op = before_op )
516526
517527 return output
518528
@@ -748,19 +758,18 @@ def _lower_complex_stft(op: Operation):
748758
749759@LowerComplex .register_lower_func (op_type = "complex_istft" )
750760def _lower_complex_istft (op : Operation ):
751- is_complex = types .is_complex (op .input .dtype )
752761
753762 # check parameters for validity
754- if is_complex :
755- raise ValueError ( "Only complex inputs are allowed " )
763+ if not types . is_complex ( op . input . dtype ) :
764+ raise TypeError ( "Input type must be complex " )
756765 if op .win_length and op .win_length .val > op .n_fft .val :
757766 raise ValueError ("Window length must be less than or equal to n_fft" )
758767 if op .return_complex and op .onesided and op .onesided .val :
759768 raise ValueError ("Complex output is not compatible with onesided" )
760769
761770 real , imag = _istft (
762771 op .input .real , op .input .imag ,
763- op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
772+ op .n_fft , op .hop_length , op .win_length , op .window , op .center , op . normalized , op .onesided , op .length , before_op = op )
764773
765774 if op .return_complex :
766775 return _wrap_complex_output (op .outputs [0 ], real , imag )
0 commit comments