Skip to content

Commit 40b6b3f

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Add ISTFT
1 parent f598456 commit 40b6b3f

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed

coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,3 +861,85 @@ def type_inference(self):
861861

862862
return types.tensor(output_type, tuple(output_shape))
863863

864+
@register_op(namespace="complex")
class complex_istft(Operation):
    """
    Dialect op for 1-D ISTFT.

    Parameters
    ----------
    input: tensor<\*V, complex64> (Required)
        * A complex tensor where real and imag parts have the same shape.
    n_fft: const i32 (Required)
        * Size of the fourier transform.
    hop_length: const i32 (Optional)
        * Stride between window frames of the input tensor.
        * When omitted, ``n_fft // 4`` is used for shape inference.
    win_length: const i32 (Optional)
        * The size of the window frame.
    window: tensor<1, win_length> (Optional)
        * The window to apply to the input signal before performing the fourier transform.
    normalized: const bool (Optional, Default=``false``)
        * Whether to normalize the results of the STFT.
    onesided: const bool (Optional, Default=``true``)
        * Whether the STFT was onesided.
    length: const i32 (Optional)
        * Fixed output length. When given, it is used directly as the size of
          the last output dimension.

    Returns
    -------
    tensor<\*D, fp32>
        * The output tensor.

    Attributes
    ----------
    T: fp32, complex64

    References
    ----------
    See `torch.istft <https://pytorch.org/docs/2.0/generated/torch.istft.html>`_.
    """

    input_spec = InputSpec(
        input=TensorInputType(type_domain="T"),
        n_fft=TensorInputType(const=True, type_domain=types.int32),
        hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
        win_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
        window=TensorInputType(const=True, optional=True, type_domain=types.fp32),
        normalized=TensorInputType(const=True, optional=True, type_domain=types.bool),
        onesided=TensorInputType(const=True, optional=True, type_domain=types.bool),
        length=TensorInputType(const=True, optional=True, type_domain=types.int32),
    )

    type_domains = {
        "T": (types.fp32, types.complex64),
    }

    def default_inputs(self):
        """Return the defaults used for every optional input of this op."""
        return DefaultInputs(
            hop_length=None,
            win_length=None,
            window=None,
            normalized=False,
            onesided=True,
            length=None,
        )

    def type_inference(self):
        """
        Infer the output tensor type.

        The element type is always fp32. The last dimension is ``length`` when
        it is provided, otherwise ``n_fft + hop_length * (n_frames - 1)`` where
        ``n_frames`` is the size of the last input dimension.
        """
        output_type = types.fp32
        output_shape = []

        # add back rank if needed
        # NOTE(review): for a (batch, freq, frames) spectrogram the batch
        # dimension would only exist at rank 3 -- confirm whether ``== 2`` is
        # the intended check. Kept as in the original.
        if self.input.rank == 2:
            output_shape.append(self.input.shape[0])

        if self.length is not None:
            # ``length`` is a const Var; the shape needs its value, not the Var.
            output_shape.append(self.length.val)
            return types.tensor(output_type, tuple(output_shape))

        n_fft = self.n_fft.val
        # ``hop_length`` is optional; fall back to the same n_fft / 4 default
        # that the lowering pass applies.
        hop_length = self.hop_length.val if self.hop_length is not None else n_fft // 4
        n_frames = self.input.shape[-1]
        # Append to the shape list instead of overwriting it with a scalar.
        output_shape.append(n_fft + hop_length * (n_frames - 1))

        return types.tensor(output_type, tuple(output_shape))
945+

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,80 @@ def _stft(
376376
real_result = cos_windows_real
377377
imag_result = sin_windows_real
378378

379+
def _istft(
    input_real: Var,
    input_imaginary: Optional[Var],
    n_fft: Var,
    hop_length: Optional[Var],
    win_length: Optional[Var],
    window: Optional[Var],
    normalized: Optional[Var],
    onesided: Optional[Var],
    before_op: Operation,
) -> Tuple[Var, Var]:
    """
    We can write ISTFT in terms of convolutions with a DFT kernel.
    At the end (expressed with the sin_base returned by _calculate_dft_matrix,
    i.e. before its sign is flipped below):
        * The real part output is: cos_base * input_real + sin_base * input_imag
        * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
    Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py

    Parameters
    ----------
    input_real: real part of the input (or the input itself when it is not complex).
    input_imaginary: imaginary part of the input, or None when the input is not complex.
    n_fft: size of the fourier transform.
    hop_length: stride between frames; defaults to n_fft / 4 (floor) when not given.
    win_length: size of the window frame; defaults to n_fft when not given.
    window: optional window multiplied into the DFT kernels.
    normalized, onesided: optional flags forwarded from the dialect op.
    before_op: every op created here is inserted before this operation.
    """
    # Set the default hop, if it's not already specified
    hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op)

    # By default, use the entire frame
    win_length = win_length or n_fft

    has_imaginary = input_imaginary is not None

    # input should always be 2D
    should_increase_rank = input_real.rank == 1
    if should_increase_rank:
        input_real = mb.expand_dims(x=input_real, axes=(0,), before_op=before_op)
        if has_imaginary:
            input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op)

    is_onesided = onesided and onesided.val
    cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)

    # create a window of centered 1s of the requested size
    # NOTE(review): win_length always has a value at this point (it defaults to
    # n_fft above), so a caller-supplied ``window`` is always replaced here --
    # confirm this is intended.
    if win_length:
        window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op)

    # apply time window
    if window:
        cos_base = mb.mul(x=window, y=cos_base, before_op=before_op)
        sin_base = mb.mul(x=window, y=sin_base, before_op=before_op)

    # The DFT matrix is obtained with the equation e^(2pi/N i), but here we need
    # its conjugate => e^(-2pi/N i); in terms of cos and sin: cos+i*sin => cos-i*sin.
    # Only the sign of the sine kernel has to be flipped.
    sin_base = mb.sub(x=0., y=sin_base, before_op=before_op)

    cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op)
    sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op)
    hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op)

    signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op)
    if has_imaginary:
        signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)

    # Conv with DFT kernel across the input signal
    # We can describe the IDFT in terms of DFT just by swapping the input and output
    # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
    # So IDFT(x) = (1/N) * swap(DFT(swap(x)))
    # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
    # If x is complex then x[n]=(a+i*b)
    # So the real part = (1/N)*Σ(a*cos(2kpi/N)-b*sin(2kpi/N))
    # So the imag part = (1/N)*Σ(b*cos(2kpi/N)+a*sin(2kpi/N))
    # (sin above refers to the sign-flipped sin_base kernel.)
    cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
    sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)

    if has_imaginary:
        cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
        sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)

        real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
        imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
    else:
        # With no imaginary part (b == 0) the formulas above reduce to the
        # real-input convolutions alone.
        real_result = cos_windows_real
        imag_result = sin_windows_real

    # Divide by N
    real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op)
    imag_result = mb.real_div(x=imag_result, y=n_fft, before_op=before_op)

    # Overlap-add
    real_result = _overlap_add(x=real_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
    imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
@@ -638,6 +712,24 @@ def _lower_complex_stft(op: Operation):
638712
return _wrap_complex_output(op.outputs[0], real, imag)
639713

640714

715+
@LowerComplex.register_lower_func(op_type="complex_istft")
def _lower_complex_istft(op: Operation):
    """
    Lower the ``complex_istft`` dialect op by delegating to ``_istft`` and
    wrapping the resulting real/imaginary pair as the op's complex output.

    Raises
    ------
    ValueError
        If ``win_length`` exceeds ``n_fft``, or if the input is complex while
        ``onesided`` is set.
    """
    source = op.input
    has_complex_input = types.is_complex(source.dtype)

    # check parameters for validity
    window_too_long = op.win_length and op.win_length.val > op.n_fft.val
    if window_too_long:
        raise ValueError("Window length must be less than or equal to n_fft")

    # NOTE(review): this mirrors the forward-STFT restriction; since ``onesided``
    # defaults to true on the op, complex inputs are rejected unless the caller
    # disables it -- confirm this is the intended rule for the inverse transform.
    if has_complex_input and op.onesided and op.onesided.val:
        raise ValueError("Onesided is only valid for real inputs")

    # A non-complex input is forwarded as the real part with no imaginary part.
    if has_complex_input:
        real_in = source.real
        imag_in = source.imag
    else:
        real_in = source
        imag_in = None

    real, imag = _istft(
        input_real=real_in,
        input_imaginary=imag_in,
        n_fft=op.n_fft,
        hop_length=op.hop_length,
        win_length=op.win_length,
        window=op.window,
        normalized=op.normalized,
        onesided=op.onesided,
        before_op=op,
    )

    return _wrap_complex_output(op.outputs[0], real, imag)
731+
732+
641733
@LowerComplex.register_lower_func(op_type="complex_shape")
def _lower_complex_shape(op: Operation):
    """Lower ``complex_shape`` to ``mb.shape`` applied to the real part of ``op.data``."""
    real_part = op.data.real
    return mb.shape(x=real_part, before_op=op)

0 commit comments

Comments
 (0)