@@ -376,6 +376,80 @@ def _stft(
376376 real_result = cos_windows_real
377377 imag_result = sin_windows_real
378378
379+ def _istft (
380+ input_real : Var ,
381+ input_imaginary : Var ,
382+ n_fft : Var ,
383+ hop_length : Optional [Var ],
384+ win_length : Optional [Var ],
385+ window : Optional [Var ],
386+ normalized : Optional [Var ],
387+ onesided : Optional [Var ],
388+ before_op : Operation ,
389+ ) -> Tuple [Var , Var ]:
390+ """
391+ We can write ISTFT in terms of convolutions with a DFT kernel.
392+ At the end:
393+ * The real part output is: cos_base * input_real + sin_base * input_imag
394+ * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
395+ Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
396+ """
397+ # Set the default hop, if it's not already specified
398+ hop_length = hop_length or mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
399+
400+ # By default, use the entire frame
401+ win_length = win_length or n_fft
402+
403+ # input should always be 2D
404+ should_increase_rank = input_real .rank == 1
405+ if should_increase_rank :
406+ input_real = mb .expand_dims (x = input_real , axes = (0 ,), before_op = before_op )
407+ if input_imaginary :
408+ input_imaginary = mb .expand_dims (x = input_imaginary , axes = (0 ,), before_op = before_op )
409+
410+ is_onesided = onesided and onesided .val
411+ cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
412+
413+ # create a window of centered 1s of the requested size
414+ if win_length :
415+ window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
416+
417+ # apply time window
418+ if window :
419+ cos_base = mb .mul (x = window , y = cos_base , before_op = before_op )
420+ sin_base = mb .mul (x = window , y = sin_base , before_op = before_op )
421+
422+ # The DFT matrix is obtained with the equation e^(2pi/N i), which is close to what we want, but we actually need the conjugate => e^(-2pi/N i)
423+ # or in terms of cos and sin: cos+i*sin => cos-i*sin
424+ sin_base = mb .sub (x = 0. , ysin_base , before_op = before_op )
425+
426+ cos_base = mb .expand_dims (x = cos_base , axes = (1 ,), before_op = before_op )
427+ sin_base = mb .expand_dims (x = sin_base , axes = (1 ,), before_op = before_op )
428+ hop_size = mb .expand_dims (x = hop_length , axes = (0 ,), before_op = before_op )
429+
430+ signal_real = mb .expand_dims (x = input_real , axes = (1 ,), before_op = before_op )
431+ signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
432+
433+ # Conv with DFT kernel across the input signal
434+ # We can describe the IDFT in terms of DFT just by swapping the input and output
435+ # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
436+ # So IDFT(x) = (1/N) * swap(DFT(swap(x)))
437+ # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
438+ # If x is complex then x[n]=(a+i*b)
439+ # So the real part = (1/N)*Σ(a*cos(2kpi/N)-b*sin(2kpi/N))
440+ # So the imag part = (1/N)*Σ(b*cos(2kpi/N)+a*sin(2kpi/N))
441+ cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
442+ sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
443+ cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
444+ sin_windows_imag = mb .conv (x = signal_imaginary , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
445+
446+ real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
447+ imag_result = mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
448+
449+ # Divide by N
450+ real_result = mb .real_div (x = real_result , y = n_fft , before_op = before_op )
451+ imag_result = mb .real_div (x = imag_result , y = n_fft , before_op = before_op )
452+
379453 # Overlap-add
380454 real_result = _overlap_add (x = real_result , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
381455 imag_result = _overlap_add (x = imag_result , n_fft = n_fft , hop_length = hop_length , before_op = before_op )
@@ -638,6 +712,24 @@ def _lower_complex_stft(op: Operation):
638712 return _wrap_complex_output (op .outputs [0 ], real , imag )
639713
640714
@LowerComplex.register_lower_func(op_type="complex_istft")
def _lower_complex_istft(op: Operation):
    """
    Lower a ``complex_istft`` op by delegating to ``_istft``.

    The op's ``input`` is split into real and imaginary parts when its dtype is
    complex; otherwise the input itself is passed as the real part and ``None``
    as the imaginary part. The two results of ``_istft`` are handed to
    ``_wrap_complex_output`` together with ``op.outputs[0]``.

    Raises:
        ValueError: if ``win_length`` is given and its value exceeds ``n_fft``,
            or if the input is complex and ``onesided`` is given with a truthy value.
    """
    source = op.input
    input_is_complex = types.is_complex(source.dtype)

    # Parameter validation happens before any lowering ops are emitted.
    requested_win_length = op.win_length
    if requested_win_length and requested_win_length.val > op.n_fft.val:
        raise ValueError("Window length must be less than or equal to n_fft")

    if input_is_complex:
        onesided_flag = op.onesided
        if onesided_flag and onesided_flag.val:
            raise ValueError("Onesided is only valid for real inputs")
        real_part = source.real
        imag_part = source.imag
    else:
        # Non-complex input: there is no imaginary component to forward.
        real_part = source
        imag_part = None

    lowered_real, lowered_imag = _istft(
        real_part,
        imag_part,
        op.n_fft,
        op.hop_length,
        op.win_length,
        op.window,
        op.normalized,
        op.onesided,
        before_op=op,
    )

    target_output = op.outputs[0]
    return _wrap_complex_output(target_output, lowered_real, lowered_imag)
731+
732+
@LowerComplex.register_lower_func(op_type="complex_shape")
def _lower_complex_shape(op: Operation):
    """
    Lower a ``complex_shape`` op to ``mb.shape`` applied to the real part of ``op.data``.

    The new op is inserted before ``op`` (via ``before_op``) and its result is returned.
    """
    real_component = op.data.real
    return mb.shape(x=real_component, before_op=op)
0 commit comments