Skip to content

Commit a46bd0a

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Fixes
1 parent fc700a0 commit a46bd0a

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9548,7 +9548,7 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95489548
if return_complex and onesided:
95499549
pytest.skip("Complex output is incompatible with onesided")
95509550

9551-
freq = n_fft*2+1 if onesided else n_fft
9551+
freq = n_fft//2+1 if onesided else n_fft
95529552
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
95539553

95549554
class ISTFTModel(torch.nn.Module):
@@ -9569,12 +9569,34 @@ def forward(self, x):
95699569
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
95709570
return x
95719571

9572-
TorchBaseTest.run_compare_torch(
9573-
input_shape,
9574-
ISTFTModel(),
9575-
backend=backend,
9576-
compute_unit=compute_unit
9577-
)
9572+
if length is not None or center is False:
9573+
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
9574+
with pytest.raises(
9575+
RuntimeError, match="istft\(.*\) window overlap add min: 1"
9576+
):
9577+
TorchBaseTest.run_compare_torch(
9578+
input_shape,
9579+
ISTFTModel(),
9580+
backend=backend,
9581+
compute_unit=compute_unit
9582+
)
9583+
elif return_complex is False:
9584+
with pytest.raises(
9585+
ValueError, match="MIL doesn't support complex data as model's output"
9586+
):
9587+
TorchBaseTest.run_compare_torch(
9588+
input_shape,
9589+
ISTFTModel(),
9590+
backend=backend,
9591+
compute_unit=compute_unit
9592+
)
9593+
else:
9594+
TorchBaseTest.run_compare_torch(
9595+
input_shape,
9596+
ISTFTModel(),
9597+
backend=backend,
9598+
compute_unit=compute_unit
9599+
)
95789600

95799601
if _HAS_TORCH_AUDIO:
95809602

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,12 +486,12 @@ def _istft(
486486
if length is not None:
487487
if length.val > expected_output_signal_len:
488488
if channels:
489-
right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op)
489+
right_pad = mb.fill(shape=(channels, length.val - expected_output_signal_len ), value=0., before_op=before_op)
490490
else:
491-
right_pad = mb.fill(shape=(expected_output_signal_len - length,), value=0., before_op=before_op)
491+
right_pad = mb.fill(shape=(length.val - expected_output_signal_len,), value=0., before_op=before_op)
492492

493-
real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op)
494-
imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op)
493+
real_result = mb.stack(values=(real_result, right_pad), axis=1, before_op=before_op)
494+
imag_result = mb.stack(values=(imag_result, right_pad), axis=1, before_op=before_op)
495495
elif length.val < expected_output_signal_len:
496496
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
497497
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)

0 commit comments

Comments
 (0)