Skip to content

Commit 729d5d4

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent 4edfd20 commit 729d5d4

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,6 @@ def type_inference(self):
934934
output_shape += [self.length]
935935
else:
936936
n_frames = self.input.shape[-1]
937-
938937
hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4
939938
output_shape += [self.n_fft.val + hop_length * (n_frames - 1)]
940939

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def _istft(
419419
win_length = win_length or n_fft
420420

421421
input_shape = mb.shape(x=input_real, before_op=before_op)
422-
if input_shape.rank == 3:
422+
if input_real.rank == 3:
423423
channels, fft_size, n_frames = input_shape.val
424424
else:
425425
channels = None
@@ -510,7 +510,7 @@ def _overlap_add(
510510
input_shape = mb.shape(x=x, before_op=before_op)
511511

512512
# Create empty output with final shape
513-
if input_shape.rank == 3:
513+
if x.rank == 3:
514514
channels, n_frames, _= input_shape.val
515515
output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op)
516516
else:
@@ -523,18 +523,18 @@ def _overlap_add(
523523
local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op)
524524

525525
# Split data into frames and iterate
526-
signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op)
526+
signal_frames = mb.split(x=x, num_splits=n_frames, axis=1 if channels else 0, before_op=before_op)
527527

528528
for frame_num, frame in enumerate(signal_frames):
529-
frame = mb.squeeze(x=frame, axes=[1], before_op=before_op)
529+
frame = mb.squeeze(x=frame, axes=[1] if channels else [0], before_op=before_op)
530530

531531
# Create index to align data frames
532532
global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op)
533533
if channels:
534534
global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op)
535535

536536
# Add data frame
537-
output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1, mode="add", before_op=before_op)
537+
output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1 if channels else 0, mode="add", before_op=before_op)
538538

539539
return output
540540

0 commit comments

Comments
 (0)