Skip to content

Commit 48c34e3

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Fixes
1 parent a46bd0a commit 48c34e3

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9535,7 +9535,7 @@ def forward(self, x):
95359535
[16, 32], # n_fft
95369536
[5, 9], # num_frames
95379537
[None, 4, 5], # hop_length
9538-
[None, 16, 9], # win_length
9538+
[None, 10, 8], # win_length
95399539
[None, torch.hann_window], # window
95409540
[False, True], # center
95419541
[False, True], # normalized
@@ -9548,6 +9548,9 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95489548
if return_complex and onesided:
95499549
pytest.skip("Complex output is incompatible with onesided")
95509550

9551+
if hop_length is None and win_length is not None:
9552+
pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")
9553+
95519554
freq = n_fft//2+1 if onesided else n_fft
95529555
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
95539556

@@ -9566,24 +9569,13 @@ def forward(self, x):
95669569
length=length,
95679570
return_complex=return_complex)
95689571
if return_complex:
9569-
x = torch.stack([torch.real(x), torch.imag(x)], dim=0)
9570-
return x
9572+
return torch.stack([torch.real(x), torch.imag(x)], dim=0)
9573+
else:
9574+
return torch.real(x)
95719575

9572-
if length is not None or center is False:
9576+
if win_length and center is False:
95739577
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
9574-
with pytest.raises(
9575-
RuntimeError, match="istft\(.*\) window overlap add min: 1"
9576-
):
9577-
TorchBaseTest.run_compare_torch(
9578-
input_shape,
9579-
ISTFTModel(),
9580-
backend=backend,
9581-
compute_unit=compute_unit
9582-
)
9583-
elif return_complex is False:
9584-
with pytest.raises(
9585-
ValueError, match="MIL doesn't support complex data as model's output"
9586-
):
9578+
with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
95879579
TorchBaseTest.run_compare_torch(
95889580
input_shape,
95899581
ISTFTModel(),

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def _istft(
427427

428428
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
429429

430-
is_onesided = onesided.val if onesided else fft_size != n_fft
430+
is_onesided = True if fft_size != n_fft.val else onesided and onesided.val
431431
cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op)
432432

433433
# create a window of centered 1s of the requested size
@@ -481,20 +481,18 @@ def _istft(
481481
window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
482482
real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op)
483483
imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op)
484-
485484
# We need to adapt last dimension
486485
if length is not None:
487486
if length.val > expected_output_signal_len:
487+
real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
488+
imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op)
489+
elif length.val < expected_output_signal_len:
488490
if channels:
489-
right_pad = mb.fill(shape=(channels, length.val - expected_output_signal_len ), value=0., before_op=before_op)
491+
real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
492+
imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op)
490493
else:
491-
right_pad = mb.fill(shape=(length.val - expected_output_signal_len,), value=0., before_op=before_op)
492-
493-
real_result = mb.stack(values=(real_result, right_pad), axis=1, before_op=before_op)
494-
imag_result = mb.stack(values=(imag_result, right_pad), axis=1, before_op=before_op)
495-
elif length.val < expected_output_signal_len:
496-
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
497-
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
494+
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
495+
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
498496

499497
return real_result, imag_result
500498

0 commit comments

Comments
 (0)