Skip to content

Commit 0c8cc8e

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Normalize by window square
1 parent 40b6b3f commit 0c8cc8e

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,16 +454,20 @@ def _istft(
454454
real_result = _overlap_add(x=real_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
455455
imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
456456

457+
# Normalize by the window square
458+
n_frames = mb.shape(x=real_result, before_op=before_op)[1]
459+
window_square = mb.mul(x=window, y=window, before_op=before_op)
460+
window_mtx = mb.stack(values=[window_square] * n_frames, axis=1)
461+
normalization_factor = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op)
462+
463+
real_result = mb.real_div(x=real_result, y=normalization_factor, before_op=before_op)
464+
imag_result = mb.real_div(x=imag_result, y=normalization_factor, before_op=before_op)
465+
457466
# reduce the rank of the output
458467
if should_increase_rank:
459468
real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op)
460469
imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op)
461470

462-
if normalized and normalized.val:
463-
divisor = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op)
464-
real_result = mb.real_div(x=real_result, y=divisor, before_op=before_op)
465-
imag_result = mb.real_div(x=imag_result, y=divisor, before_op=before_op)
466-
467471
return real_result, imag_result
468472

469473
def _overlap_add(
@@ -473,7 +477,7 @@ def _overlap_add(
473477
before_op: Operation,
474478
) -> Var:
475479
n_frames = mb.shape(x=x, before_op=before_op)[1]
476-
output = mb.fill(shape=(n_fft + hop_length * (n_frames - 1)), value=0., before_op=before_op)
480+
output = mb.fill(shape=(n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op)
477481
signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op)
478482
local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op)
479483

0 commit comments

Comments
 (0)