@@ -419,7 +419,7 @@ def _istft(
419419 win_length = win_length or n_fft
420420
421421 input_shape = mb .shape (x = input_real , before_op = before_op )
422- if input_shape .rank == 3 :
422+ if input_real .rank == 3 :
423423 channels , fft_size , n_frames = input_shape .val
424424 else :
425425 channels = None
@@ -510,7 +510,7 @@ def _overlap_add(
510510 input_shape = mb .shape (x = x , before_op = before_op )
511511
512512 # Create empty output with final shape
513- if input_shape .rank == 3 :
513+ if x .rank == 3 :
514514 channels , n_frames , _ = input_shape .val
515515 output = mb .fill (shape = (channels , int (n_fft .val + hop_length .val * (n_frames - 1 )),), value = 0. , before_op = before_op )
516516 else :
@@ -523,18 +523,18 @@ def _overlap_add(
523523 local_idx = mb .range_1d (start = 0 , end = n_fft , step = 1 , before_op = before_op )
524524
525525 # Split data into frames and iterate
526- signal_frames = mb .split (x = x , num_splits = n_frames , axis = 1 , before_op = before_op )
526+ signal_frames = mb .split (x = x , num_splits = n_frames , axis = 1 if channels else 0 , before_op = before_op )
527527
528528 for frame_num , frame in enumerate (signal_frames ):
529- frame = mb .squeeze (x = frame , axes = [1 ], before_op = before_op )
529+ frame = mb .squeeze (x = frame , axes = [1 ] if channels else [ 0 ] , before_op = before_op )
530530
531531 # Create index to align data frames
532532 global_idx = mb .add (x = local_idx , y = frame_num * hop_length .val , before_op = before_op )
533533 if channels :
534534 global_idx = mb .stack (values = [global_idx ] * channels , axis = 0 , before_op = before_op )
535535
536536 # Add data frame
537- output = mb .scatter_along_axis (data = output , indices = global_idx , updates = frame , axis = 1 , mode = "add" , before_op = before_op )
537+ output = mb .scatter_along_axis (data = output , indices = global_idx , updates = frame , axis = 1 if channels else 0 , mode = "add" , before_op = before_op )
538538
539539 return output
540540
0 commit comments