Skip to content

Commit 866d61e

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Fixes
1 parent 729d5d4 commit 866d61e

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9610,7 +9610,7 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
96109610
if return_complex and onesided:
96119611
pytest.skip("Complex output is incompatible with onesided")
96129612

9613-
freq = n_fft*2+1 if onesided else n_fft
9613+
freq = n_fft//2+1 if onesided else n_fft
96149614
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
96159615

96169616
class ISTFTModel(torch.nn.Module):
@@ -9631,12 +9631,34 @@ def forward(self, x):
96319631
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
96329632
return x
96339633

9634-
TorchBaseTest.run_compare_torch(
9635-
input_shape,
9636-
ISTFTModel(),
9637-
backend=backend,
9638-
compute_unit=compute_unit
9639-
)
9634+
if length is not None or center is False:
9635+
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
9636+
with pytest.raises(
9637+
RuntimeError, match="istft\(.*\) window overlap add min: 1"
9638+
):
9639+
TorchBaseTest.run_compare_torch(
9640+
input_shape,
9641+
ISTFTModel(),
9642+
backend=backend,
9643+
compute_unit=compute_unit
9644+
)
9645+
elif return_complex is False:
9646+
with pytest.raises(
9647+
ValueError, match="MIL doesn't support complex data as model's output"
9648+
):
9649+
TorchBaseTest.run_compare_torch(
9650+
input_shape,
9651+
ISTFTModel(),
9652+
backend=backend,
9653+
compute_unit=compute_unit
9654+
)
9655+
else:
9656+
TorchBaseTest.run_compare_torch(
9657+
input_shape,
9658+
ISTFTModel(),
9659+
backend=backend,
9660+
compute_unit=compute_unit
9661+
)
96409662

96419663
if _HAS_TORCH_AUDIO:
96429664

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,12 +486,12 @@ def _istft(
486486
if length is not None:
487487
if length.val > expected_output_signal_len:
488488
if channels:
489-
right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op)
489+
right_pad = mb.fill(shape=(channels, length.val - expected_output_signal_len ), value=0., before_op=before_op)
490490
else:
491-
right_pad = mb.fill(shape=(expected_output_signal_len - length,), value=0., before_op=before_op)
491+
right_pad = mb.fill(shape=(length.val - expected_output_signal_len,), value=0., before_op=before_op)
492492

493-
real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op)
494-
imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op)
493+
real_result = mb.stack(values=(real_result, right_pad), axis=1, before_op=before_op)
494+
imag_result = mb.stack(values=(imag_result, right_pad), axis=1, before_op=before_op)
495495
elif length.val < expected_output_signal_len:
496496
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
497497
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)

0 commit comments

Comments
 (0)