Skip to content

Commit 0d3238a

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent c2686c8 commit 0d3238a

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9589,11 +9589,13 @@ def forward(self, x):
95899589

95909590
@pytest.mark.slow
95919591
@pytest.mark.parametrize(
9592-
"compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex",
9592+
"compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex",
95939593
itertools.product(
95949594
compute_units,
95959595
backends,
9596-
[(1, 32, 9), (32, 9), (3, 32, 9)], # input shape
9596+
[None, 1, 3], # channels
9597+
[16, 32], # n_fft
9598+
[5, 9], # num_frames
95979599
[None, 4, 5], # hop_length
95989600
[None, 16, 9], # win_length
95999601
[None, torch.hann_window], # window
@@ -9604,11 +9606,12 @@ def forward(self, x):
96049606
[False, True], # return_complex
96059607
)
96069608
)
9607-
def test_istft(self, compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
9609+
def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex):
96089610
if return_complex and onesided:
96099611
pytest.skip("Complex output is incompatible with onesided")
96109612

9611-
n_fft = input_shape[1]
9613+
freq = n_fft*2+1 if onesided else n_fft
9614+
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
96129615

96139616
class ISTFTModel(torch.nn.Module):
96149617
def forward(self, x):

coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -938,5 +938,4 @@ def type_inference(self):
938938
hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4
939939
output_shape += [self.n_fft.val + hop_length * (n_frames - 1)]
940940

941-
942941
return types.tensor(output_type, tuple(output_shape))

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -419,9 +419,12 @@ def _istft(
419419
win_length = win_length or n_fft
420420

421421
input_shape = mb.shape(x=input_real, before_op=before_op)
422-
channels = input_shape.val[0]
423-
fft_size = input_shape.val[1]
424-
n_frames = input_shape.val[2]
422+
if input_shape.rank == 3:
423+
channels, fft_size, n_frames = input_shape.val
424+
else:
425+
channels = None
426+
fft_size, n_frames = input_shape.val
427+
425428
expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
426429

427430
is_onesided = onesided.val if onesided else fft_size != n_fft
@@ -482,12 +485,16 @@ def _istft(
482485
# We need to adapt last dimension
483486
if length is not None:
484487
if length.val > expected_output_signal_len:
485-
right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op)
488+
if channels:
489+
right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op)
490+
else:
491+
right_pad = mb.fill(shape=(expected_output_signal_len - length,), value=0., before_op=before_op)
492+
486493
real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op)
487494
imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op)
488495
elif length.val < expected_output_signal_len:
489-
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op)
490-
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op)
496+
real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op)
497+
imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op)
491498

492499
return real_result, imag_result
493500

@@ -498,14 +505,18 @@ def _overlap_add(
498505
before_op: Operation,
499506
) -> Var:
500507
"""
501-
The input has shape (channels, fft_size, n_frames)
508+
The input has shape (channels, n_frames, fft_size)
502509
"""
503510
input_shape = mb.shape(x=x, before_op=before_op)
504-
channels = input_shape.val[0]
505-
n_frames = input_shape.val[1]
506511

507512
# Create empty output with final shape
508-
output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op)
513+
if input_shape.rank == 3:
514+
channels, n_frames = input_shape.val
515+
output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op)
516+
else:
517+
channels = None
518+
n_frames= input_shape.val
519+
output = mb.fill(shape=(int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op)
509520

510521
# Create an index used later on overlap add
511522
n_fft = mb.cast(x=n_fft, dtype="int32", before_op=before_op)
@@ -519,7 +530,8 @@ def _overlap_add(
519530

520531
# Create index to align data frames
521532
global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op)
522-
global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op)
533+
if channels:
534+
global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op)
523535

524536
# Add data frame
525537
output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1, mode="add", before_op=before_op)

0 commit comments

Comments
 (0)