From 47384780d62e1c199dee8d290013fdc95a500549 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sat, 23 Aug 2025 20:16:05 +0900 Subject: [PATCH 1/3] Revert "[torchlib] Unregister stft, var, var_mean, std, std_mean" (#1867) This reverts commit 1eef63304555f4ce7686d9ed20657367b64ae323. --- .../function_libs/torch_lib/ops/core.py | 133 ++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 8 ++ 2 files changed, 141 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2cbecdcfc2..8f1e2f7e1b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8548,6 +8548,139 @@ def aten_std_mean_correction( return op.Sqrt(var), mean +@torch_op("aten::stft", private=True) +def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]: + signal_rank = Rank(self) + if signal_rank == 1: + # Add a batch dimension + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + return op.Identity(self), signal_rank + + +@torch_op("aten::stft", private=True) +def _center_window_around_zeros_if_needed( + window: TFloatOrBFloat16, n_fft: int +) -> TFloatOrBFloat16: + # first dimension + n_win = op.Shape(window, start=0, end=1) + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left = (n_fft - n_win) / 2 + + right = n_fft - left - n_win + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + right_win = op.CastLike(right_win, window) + left_win = op.CastLike(left_win, window) + window = op.Concat(left_win, window, right_win, axis=0) + return window + + +@torch_op("aten::stft", private=True) +def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16: + left = (n_fft - win_length) / 2 + + right = n_fft - left - win_length + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + window_list = op.Expand(op.Constant(value_ints=[1]), win_length) + return op.Concat(left_win, window_list, right_win, axis=0) + + +@torch_op("aten::stft", private=True) +def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) + return window + + +@torch_op("aten::stft", private=True) +def _normalize_fft_result( + signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int +) -> TFloatOrBFloat16: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) + result = result / sqrt_nfft + return result + + +@torch_op("aten::stft", private=True) +def _aten_stft_onnx( + signal: TFloatOrBFloat16, + frame_step_const: INT64, + window: Union[TFloatOrBFloat16, INT64], + frame_length_const: INT64, + signal_rank: INT64, + onesided: int, +) -> TFloatOrBFloat16: + window = op.CastLike(window, signal) + result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided) + result = op.Transpose(result, perm=[0, 2, 1, 3]) + # Remove batch dimension, if needed + if signal_rank == 1: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + return result + + +@torch_op("aten::stft", trace_only=True) +def aten_stft( + self: TFloatOrBFloat16, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[TFloatOrBFloat16] = None, + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, +) -> TFloatOrBFloat16: + """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" + + # NOTE: regarless of the value of return_complex, we always return a real representation. + del return_complex + + # Get STFT sizes + if hop_length is None: + # core dump + # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) + hop_length = n_fft // 4 + frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) + frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) + + # Pre-process input if needed + self, signal_rank = _add_batch_dimension(self) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + if window is not None and window.shape[0] is not None: + window = _center_window_around_zeros_if_needed(window, n_fft) + elif window is None: + if win_length is not None: + window = _create_window_from_win_length(win_length, n_fft) + else: + window = _create_window_from_n_fft(n_fft) + + if onesided is None or onesided: + onesided = 1 + else: + onesided = 0 + # remove batch dimension included + result = _aten_stft_onnx( + self, frame_step_const, window, frame_length_const, signal_rank, onesided + ) + + # Normalize, if needed + if normalized: + result = _normalize_fft_result(self, result, n_fft) + + return result + + @torch_op( ( "aten::sub.Tensor", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf31..4ef7550b6e 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1760,6 +1760,14 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), + TorchLibOpInfo( + "ops.aten.stft", # Custom from extra_opinfo + core_ops.aten_stft, + tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + ).xfail( + dtypes=(torch.float16,), + reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", + ), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, From 111f4ed955126cb0b57c98b7a47ff07c75c1d759 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sat, 18 Oct 2025 03:57:05 +0900 Subject: [PATCH 2/3] Fix aten_stft --- .../function_libs/torch_lib/ops/core.py | 107 ++++++------------ 1 file changed, 36 insertions(+), 71 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8f1e2f7e1b..44d8513f59 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8548,42 +8548,10 @@ def aten_std_mean_correction( return op.Sqrt(var), mean -@torch_op("aten::stft", private=True) -def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]: - signal_rank = Rank(self) - if signal_rank == 1: - # Add a batch dimension - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - return op.Identity(self), signal_rank - - -@torch_op("aten::stft", private=True) -def _center_window_around_zeros_if_needed( - window: TFloatOrBFloat16, n_fft: int -) -> TFloatOrBFloat16: - # first dimension - n_win = op.Shape(window, start=0, end=1) - # Center window around zeros if needed (required by ONNX's STFT) - if n_win < n_fft: - left = (n_fft - n_win) / 2 - - right = n_fft - left - n_win - left = op.Reshape(left, op.Constant(value_ints=[1])) - right = op.Reshape(right, op.Constant(value_ints=[1])) - - left_win = op.Expand(op.Constant(value_ints=[0]), left) - right_win = op.Expand(op.Constant(value_ints=[0]), right) - right_win = op.CastLike(right_win, window) - left_win = op.CastLike(left_win, window) - window = op.Concat(left_win, window, right_win, axis=0) - return window - - -@torch_op("aten::stft", private=True) -def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16: - left = (n_fft - win_length) / 2 +def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: + left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2])) - right = n_fft - left - win_length + right = op.Sub(op.Sub(n_fft, left), win_length) left = op.Reshape(left, op.Constant(value_ints=[1])) right = op.Reshape(right, op.Constant(value_ints=[1])) win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) @@ -8594,71 +8562,66 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa return op.Concat(left_win, window_list, right_win, axis=0) -@torch_op("aten::stft", private=True) -def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16: +def _create_window_from_n_fft(n_fft: int) -> TFloat: n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) return window -@torch_op("aten::stft", private=True) -def _normalize_fft_result( - signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int -) -> TFloatOrBFloat16: +def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat: n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) - result = result / sqrt_nfft - return result - - -@torch_op("aten::stft", private=True) -def _aten_stft_onnx( - signal: TFloatOrBFloat16, - frame_step_const: INT64, - window: Union[TFloatOrBFloat16, INT64], - frame_length_const: INT64, - signal_rank: INT64, - onesided: int, -) -> TFloatOrBFloat16: - window = op.CastLike(window, signal) - result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided) - result = op.Transpose(result, perm=[0, 2, 1, 3]) - # Remove batch dimension, if needed - if signal_rank == 1: - result = op.Squeeze(result, op.Constant(value_ints=[0])) + result = op.Div(result, sqrt_nfft) return result @torch_op("aten::stft", trace_only=True) def aten_stft( - self: TFloatOrBFloat16, + self: TFloat, n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, - window: Optional[TFloatOrBFloat16] = None, + window: Optional[TFloat] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, -) -> TFloatOrBFloat16: +) -> TFloat: """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" - # NOTE: regarless of the value of return_complex, we always return a real representation. + # NOTE: regardless of the value of return_complex, we always return a real representation. del return_complex # Get STFT sizes if hop_length is None: # core dump - # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) + # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) hop_length = n_fft // 4 frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) # Pre-process input if needed - self, signal_rank = _add_batch_dimension(self) + is_signal_rank1 = len(self.shape) == 1 + if is_signal_rank1: + # Add a batch dimension + self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) # Get window and make sure it's the same size as `win_length` or `n_fft` if window is not None and window.shape[0] is not None: - window = _center_window_around_zeros_if_needed(window, n_fft) + # first dimension + n_win = op.Shape(window, start=0, end=1) + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2])) + + right = op.Sub(op.Sub(n_fft, left), n_win) + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + right_win = op.CastLike(right_win, window) + left_win = op.CastLike(left_win, window) + window = op.Concat(left_win, window, right_win, axis=0) elif window is None: if win_length is not None: window = _create_window_from_win_length(win_length, n_fft) @@ -8669,10 +8632,12 @@ def aten_stft( onesided = 1 else: onesided = 0 - # remove batch dimension included - result = _aten_stft_onnx( - self, frame_step_const, window, frame_length_const, signal_rank, onesided - ) + window = op.CastLike(window, self) + result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided) + result = op.Transpose(result, perm=[0, 2, 1, 3]) + # Remove batch dimension, if needed + if is_signal_rank1: + result = op.Squeeze(result, op.Constant(value_ints=[0])) # Normalize, if needed if normalized: From 29ba6b99b05295a4136cdea15209cbd4094d67c1 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sun, 2 Nov 2025 22:44:06 +0900 Subject: [PATCH 3/3] Add test and fix impl --- .../function_libs/torch_lib/ops/core.py | 3 +- .../function_libs/torch_lib/e2e_ops_tests.py | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 44d8513f59..09704199f9 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8597,7 +8597,6 @@ def aten_stft( # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) hop_length = n_fft // 4 frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) - frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) # Pre-process input if needed is_signal_rank1 = len(self.shape) == 1 @@ -8633,7 +8632,7 @@ def aten_stft( else: onesided = 0 window = op.CastLike(window, self) - result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided) + result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided) result = op.Transpose(result, perm=[0, 2, 1, 3]) # Remove batch dimension, if needed if is_signal_rank1: diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index f74dda699d..cb272a98a6 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -406,6 +406,75 @@ def forward(self, x): onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) _testing.assert_onnx_program(onnx_program) + def test_aten_stft_1(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=4, return_complex=True) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=4, return_complex=False) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_3(self): + class Model(torch.nn.Module): + def forward(self, x): + window = torch.ones(16, dtype=torch.float32) + return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False) + + x = torch.randn(100, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_4(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft( + x, + n_fft=4, + hop_length=1, + win_length=4, + center=True, + onesided=True, + return_complex=True, + ) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main()