Skip to content

Commit c2686c8

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent e0cfffd commit c2686c8

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9597,7 +9597,7 @@ def forward(self, x):
95979597
[None, 4, 5], # hop_length
95989598
[None, 16, 9], # win_length
95999599
[None, torch.hann_window], # window
9600-
[None, False, True], # center
9600+
[False, True], # center
96019601
[False, True], # normalized
96029602
[None, False, True], # onesided
96039603
[None, 60], # length

coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,9 @@ def type_inference(self):
934934
output_shape += [self.length]
935935
else:
936936
n_frames = self.input.shape[-1]
937-
output_shape += [self.n_fft.val + self.hop_length.val * (n_frames - 1)]
937+
938+
hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4
939+
output_shape += [self.n_fft.val + hop_length * (n_frames - 1)]
938940

939941

940942
return types.tensor(output_type, tuple(output_shape))

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,11 +481,11 @@ def _istft(
481481

482482
# We need to adapt last dimension
483483
if length is not None:
484-
if length > expected_output_signal_len:
484+
if length.val > expected_output_signal_len:
485485
right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op)
486486
real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op)
487487
imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op)
488-
elif length < expected_output_signal_len:
488+
elif length.val < expected_output_signal_len:
489489
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op)
490490
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op)
491491

0 commit comments

Comments
 (0)