Skip to content

Commit e0cfffd

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Fixes
1 parent fe0ad05 commit e0cfffd

File tree

4 files changed

+70
-41
lines changed

4 files changed

+70
-41
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6362,6 +6362,7 @@ def stft(context, node):
63626362
Lowers torch.stft with the dialect op `complex_stft` from complex_dialect_ops.py
63636363
"""
63646364
input_data, n_fft, hop_length, win_length, window, normalized, onesided, _ = _get_inputs(context, node, min_expected=2)
6365+
63656366
if types.is_complex(input_data.dtype):
63666367
onesided = False # pytorch defaults onesided to False for complex inputs
63676368
stft_res = mb.complex_stft(
@@ -6371,9 +6372,32 @@ def stft(context, node):
63716372
win_length=win_length,
63726373
window=window,
63736374
normalized=normalized,
6374-
onesided=onesided)
6375+
onesided=onesided
6376+
)
63756377
context.add(stft_res, node.name)
63766378

6379+
@register_torch_op
6380+
def istft(context, node):
6381+
"""
6382+
Lowers torch.istft with the dialect op `complex_istft` from complex_dialect_ops.py
6383+
"""
6384+
input_data, n_fft, hop_length, win_length, window, center, normalized, onesided, length, _ = _get_inputs(context, node, min_expected=2)
6385+
6386+
if types.is_complex(input_data.dtype):
6387+
onesided = False # pytorch defaults onesided to False for complex inputs
6388+
istft_res = mb.complex_istft(
6389+
input=input_data,
6390+
n_fft=n_fft,
6391+
hop_length=hop_length,
6392+
win_length=win_length,
6393+
window=window,
6394+
center=center,
6395+
normalized=normalized,
6396+
onesided=onesided,
6397+
length=length,
6398+
)
6399+
context.add(istft_res, node.name)
6400+
63776401
@register_torch_op(torch_alias=["torchvision::nms"])
63786402
def torchvision_nms(context, node):
63796403
inputs = _get_inputs(context, node, expected=3)

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9589,12 +9589,11 @@ def forward(self, x):
95899589

95909590
@pytest.mark.slow
95919591
@pytest.mark.parametrize(
9592-
"compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex",
9592+
"compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex",
95939593
itertools.product(
95949594
compute_units,
95959595
backends,
95969596
[(1, 32, 9), (32, 9), (3, 32, 9)], # input shape
9597-
[16], # n_fft
95989597
[None, 4, 5], # hop_length
95999598
[None, 16, 9], # win_length
96009599
[None, torch.hann_window], # window
@@ -9605,10 +9604,12 @@ def forward(self, x):
96059604
[False, True], # return_complex
96069605
)
96079606
)
9608-
def test_istft(self, compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
9607+
def test_istft(self, compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
96099608
if return_complex and onesided:
96109609
pytest.skip("Complex output is incompatible with onesided")
96119610

9611+
n_fft = input_shape[1]
9612+
96129613
class ISTFTModel(torch.nn.Module):
96139614
def forward(self, x):
96149615
applied_window = window(win_length) if window and win_length else None

coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,6 @@ class complex_istft(Operation):
893893
894894
Attributes
895895
----------
896-
V: complex64
897896
T: fp32, complex64
898897
899898
References
@@ -902,44 +901,40 @@ class complex_istft(Operation):
902901
"""
903902

904903
input_spec = InputSpec(
905-
input=TensorInputType(type_domain="V"),
904+
input=TensorInputType(type_domain=types.complex),
906905
n_fft=TensorInputType(const=True, type_domain=types.int32),
907906
hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
908907
win_length=TensorInputType(const=True, optional=True, type_domain=types.int32),
909908
window=TensorInputType(const=True, optional=True, type_domain=types.fp32),
910-
normalized=TensorInputType(const=True, optional=True, type_domain=types.bool),
909+
center=TensorInputType(const=True, type_domain=types.bool),
910+
normalized=TensorInputType(const=True, optional=False, type_domain=types.bool),
911911
onesided=TensorInputType(const=True, optional=True, type_domain=types.bool),
912912
length=TensorInputType(const=True, optional=True, type_domain=types.int32),
913+
return_complex=TensorInputType(const=True, optional=True, type_domain=types.bool),
913914
)
914915

915-
type_domains = {
916-
"V": types.complex64,
917-
}
918-
919916
def default_inputs(self):
920917
return DefaultInputs(
921918
hop_length = None,
922919
win_length = None,
923920
window = None,
924921
normalized = False,
925922
onesided = True,
926-
length = None
923+
length = None,
924+
return_complex = True,
927925
)
928926

929927
def type_inference(self):
930-
output_type = (types.fp32)
931-
output_shape = []
928+
output_type = (types.complex64) if self.return_complex else (types.fp32)
932929

933-
# add back rank if needed
934-
if self.input.rank == 2:
935-
output_shape += [self.input.shape[0]]
930+
# add batch size if given
931+
output_shape = [self.input.shape[0] if self.input.rank == 3 else 1]
936932

937933
if self.length:
938934
output_shape += [self.length]
939-
return types.tensor(output_type, tuple(output_shape))
935+
else:
936+
n_frames = self.input.shape[-1]
937+
output_shape += [self.n_fft.val + self.hop_length.val * (n_frames - 1)]
940938

941-
n_frames = self.input.shape[-1]
942-
output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1)
943939

944940
return types.tensor(output_type, tuple(output_shape))
945-

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def _istft(
397397
hop_length: Optional[Var],
398398
win_length: Optional[Var],
399399
window: Optional[Var],
400+
center: Optional[Var],
400401
normalized: Optional[Var],
401402
onesided: Optional[Var],
402403
length: Optional[Var],
@@ -435,12 +436,10 @@ def _istft(
435436
cos_base = mb.mul(x=window, y=cos_base, before_op=before_op)
436437
sin_base = mb.mul(x=window, y=sin_base, before_op=before_op)
437438

438-
cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op)
439-
sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op)
440439
hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op)
441440

442-
signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op)
443-
signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op)
441+
signal_real = input_real
442+
signal_imaginary = input_imaginary
444443

445444
# De-normalized signal before applying the IFT
446445
if normalized and normalized.val:
@@ -455,15 +454,16 @@ def _istft(
455454
# So using the definition in stft function, we get:
456455
# real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
457456
# imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458-
cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
459-
sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
460-
cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op)
461-
sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op)
457+
cos_windows_real = mb.matmul(x=signal_real, y=cos_base, transpose_x=True, before_op=before_op)
458+
sin_windows_real = mb.matmul(x=signal_real, y=sin_base, transpose_x=True, before_op=before_op)
459+
cos_windows_imag = mb.matmul(x=signal_imaginary, y=cos_base, transpose_x=True, before_op=before_op)
460+
sin_windows_imag = mb.matmul(x=signal_imaginary, y=sin_base, transpose_x=True, before_op=before_op)
462461

463462
real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op)
464463
imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op)
465464

466465
# Divide by N
466+
n_fft = mb.cast(x=n_fft, dtype="fp32", before_op=before_op)
467467
real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op)
468468
imag_result = mb.real_div(x=imag_result, y=n_fft, before_op=before_op)
469469

@@ -472,9 +472,9 @@ def _istft(
472472
imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
473473

474474
# Normalize by the window square
475-
n_frames = mb.shape(x=real_result, before_op=before_op)[1]
476475
window_square = mb.mul(x=window, y=window, before_op=before_op)
477-
window_mtx = mb.stack(values=[window_square] * n_frames, axis=1)
476+
window_mtx = mb.stack(values=[window_square] * n_frames, axis=0, before_op=before_op)
477+
window_mtx = mb.expand_dims(x=window_mtx, axes=(0,), before_op=before_op)
478478
window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
479479
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
480480
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
@@ -502,17 +502,27 @@ def _overlap_add(
502502
"""
503503
input_shape = mb.shape(x=x, before_op=before_op)
504504
channels = input_shape.val[0]
505-
n_frames = input_shape.val[2]
505+
n_frames = input_shape.val[1]
506+
507+
# Create empty output with final shape
508+
output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op)
506509

507-
output = mb.fill(shape=(channels, n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op)
508-
signal_frames = mb.split(x=x, num_splits=n_frames, axis=2, before_op=before_op)
510+
# Create an index used later on overlap add
511+
n_fft = mb.cast(x=n_fft, dtype="int32", before_op=before_op)
509512
local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op)
510513

514+
# Split data into frames and iterate
515+
signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op)
516+
511517
for frame_num, frame in enumerate(signal_frames):
518+
frame = mb.squeeze(x=frame, axes=[1], before_op=before_op)
519+
520+
# Create index to align data frames
512521
global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op)
513-
global_idx = mb.expand_dims(x=global_idx, axes=(0,), before_op=before_op)
514-
global_idx = mb.stack(values=[global_idx] * channels, axis=0)
515-
output = mb.scatter_nd(data=output, indices=global_idx, updates=frame, before_op=before_op)
522+
global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op)
523+
524+
# Add data frame
525+
output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1, mode="add", before_op=before_op)
516526

517527
return output
518528

@@ -748,19 +758,18 @@ def _lower_complex_stft(op: Operation):
748758

749759
@LowerComplex.register_lower_func(op_type="complex_istft")
750760
def _lower_complex_istft(op: Operation):
751-
is_complex = types.is_complex(op.input.dtype)
752761

753762
# check parameters for validity
754-
if is_complex:
755-
raise ValueError("Only complex inputs are allowed")
763+
if not types.is_complex(op.input.dtype):
764+
raise TypeError("Input type must be complex")
756765
if op.win_length and op.win_length.val > op.n_fft.val:
757766
raise ValueError("Window length must be less than or equal to n_fft")
758767
if op.return_complex and op.onesided and op.onesided.val:
759768
raise ValueError("Complex output is not compatible with onesided")
760769

761770
real, imag = _istft(
762771
op.input.real, op.input.imag,
763-
op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, op.length, before_op=op)
772+
op.n_fft, op.hop_length, op.win_length, op.window, op.center, op.normalized, op.onesided, op.length, before_op=op)
764773

765774
if op.return_complex:
766775
return _wrap_complex_output(op.outputs[0], real, imag)

0 commit comments

Comments
 (0)