@@ -419,9 +419,12 @@ def _istft(
419419 win_length = win_length or n_fft
420420
421421 input_shape = mb .shape (x = input_real , before_op = before_op )
422- channels = input_shape .val [0 ]
423- fft_size = input_shape .val [1 ]
424- n_frames = input_shape .val [2 ]
422+ if input_shape .rank == 3 :
423+ channels , fft_size , n_frames = input_shape .val
424+ else :
425+ channels = None
426+ fft_size , n_frames = input_shape .val
427+
425428 expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
426429
427430 is_onesided = onesided .val if onesided else fft_size != n_fft
@@ -482,12 +485,16 @@ def _istft(
482485 # We need to adapt last dimension
483486 if length is not None :
484487 if length .val > expected_output_signal_len :
485- right_pad = mb .fill (shape = (channels , expected_output_signal_len - length ), value = 0. , before_op = before_op )
488+ if channels :
489+ right_pad = mb .fill (shape = (channels , expected_output_signal_len - length ), value = 0. , before_op = before_op )
490+ else :
491+ right_pad = mb .fill (shape = (expected_output_signal_len - length ,), value = 0. , before_op = before_op )
492+
486493 real_result = mb .stack (x = (real_result , right_pad ), axis = 1 , before_op = before_op )
487494 imag_result = mb .stack (x = (imag_result , right_pad ), axis = 1 , before_op = before_op )
488495 elif length .val < expected_output_signal_len :
489- real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length ], before_op = before_op )
490- imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length ], before_op = before_op )
496+ real_result = mb .slice_by_size (x = real_result , begin = [0 ], size = [length . val ], before_op = before_op )
497+ imag_result = mb .slice_by_size (x = imag_result , begin = [0 ], size = [length . val ], before_op = before_op )
491498
492499 return real_result , imag_result
493500
@@ -498,14 +505,18 @@ def _overlap_add(
498505 before_op : Operation ,
499506) -> Var :
500507 """
501- The input has shape (channels, fft_size, n_frames )
508+ The input has shape (channels, n_frames, fft_size )
502509 """
503510 input_shape = mb .shape (x = x , before_op = before_op )
504- channels = input_shape .val [0 ]
505- n_frames = input_shape .val [1 ]
506511
507512 # Create empty output with final shape
508- output = mb .fill (shape = (channels , int (n_fft .val + hop_length .val * (n_frames - 1 ))), value = 0. , before_op = before_op )
513+ if input_shape .rank == 3 :
514+ channels , n_frames = input_shape .val
515+ output = mb .fill (shape = (channels , int (n_fft .val + hop_length .val * (n_frames - 1 ))), value = 0. , before_op = before_op )
516+ else :
517+ channels = None
518+ n_frames = input_shape .val
519+ output = mb .fill (shape = (int (n_fft .val + hop_length .val * (n_frames - 1 )),), value = 0. , before_op = before_op )
509520
510521 # Create an index used later on overlap add
511522 n_fft = mb .cast (x = n_fft , dtype = "int32" , before_op = before_op )
@@ -519,7 +530,8 @@ def _overlap_add(
519530
520531 # Create index to align data frames
521532 global_idx = mb .add (x = local_idx , y = frame_num * hop_length .val , before_op = before_op )
522- global_idx = mb .stack (values = [global_idx ] * channels , axis = 0 , before_op = before_op )
533+ if channels :
534+ global_idx = mb .stack (values = [global_idx ] * channels , axis = 0 , before_op = before_op )
523535
524536 # Add data frame
525537 output = mb .scatter_along_axis (data = output , indices = global_idx , updates = frame , axis = 1 , mode = "add" , before_op = before_op )
0 commit comments