
Commit 3829c30

Revert "[torchlib] Unregister stft, var, var_mean, std, std_mean" (#1867)
This reverts commit 1eef633.
Parent: 8a94ad6
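In effect, this revert restores the `aten::stft` registration in torchlib, re-enabling ONNX export for models that call `torch.stft` (along with the `var`/`std` variants). As a hedged sketch of the kind of model this unblocks (the module and the exporter call are illustrative assumptions, not part of this commit):

```python
import torch

class Spectrogram(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.stft requires return_complex; the torchlib aten_stft in this
        # commit always produces a real-valued representation regardless.
        return torch.stft(
            x,
            n_fft=400,
            hop_length=100,
            window=torch.hann_window(400),
            return_complex=True,
        )

# Export sketch, assuming the dynamo-based exporter that consumes torchlib:
# onnx_program = torch.onnx.dynamo_export(Spectrogram(), torch.randn(2, 16000))
```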

2 files changed: +141 -0 lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 133 additions & 0 deletions
@@ -8118,6 +8118,139 @@ def aten_std_mean_correction(
     return op.Sqrt(var), mean


+@torch_op("aten::stft", private=True)
+def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
+    signal_rank = Rank(self)
+    if signal_rank == 1:
+        # Add a batch dimension
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    return op.Identity(self), signal_rank
+
+
+@torch_op("aten::stft", private=True)
+def _center_window_around_zeros_if_needed(
+    window: TFloatOrBFloat16, n_fft: int
+) -> TFloatOrBFloat16:
+    # n_win is the window's first (and only) dimension
+    n_win = op.Shape(window, start=0, end=1)
+    # Center the window among zeros if it is shorter than n_fft
+    # (ONNX's STFT requires the window length to match the frame length)
+    if n_win < n_fft:
+        left = (n_fft - n_win) / 2
+
+        right = n_fft - left - n_win
+        left = op.Reshape(left, op.Constant(value_ints=[1]))
+        right = op.Reshape(right, op.Constant(value_ints=[1]))
+
+        left_win = op.Expand(op.Constant(value_ints=[0]), left)
+        right_win = op.Expand(op.Constant(value_ints=[0]), right)
+        right_win = op.CastLike(right_win, window)
+        left_win = op.CastLike(left_win, window)
+        window = op.Concat(left_win, window, right_win, axis=0)
+    return window
+
+
+@torch_op("aten::stft", private=True)
+def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
+    # Build a rectangular (all-ones) window of length win_length,
+    # centered among zeros so that the total length is n_fft
+    left = (n_fft - win_length) / 2
+
+    right = n_fft - left - win_length
+    left = op.Reshape(left, op.Constant(value_ints=[1]))
+    right = op.Reshape(right, op.Constant(value_ints=[1]))
+    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
+
+    left_win = op.Expand(op.Constant(value_ints=[0]), left)
+    right_win = op.Expand(op.Constant(value_ints=[0]), right)
+    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
+    return op.Concat(left_win, window_list, right_win, axis=0)
+
+
+@torch_op("aten::stft", private=True)
+def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
+    # Rectangular (all-ones) window covering the whole frame
+    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
+    return window
+
+
+@torch_op("aten::stft", private=True)
+def _normalize_fft_result(
+    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
+) -> TFloatOrBFloat16:
+    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
+    result = result / sqrt_nfft
+    return result
+
+
+@torch_op("aten::stft", private=True)
+def _aten_stft_onnx(
+    signal: TFloatOrBFloat16,
+    frame_step_const: INT64,
+    window: Union[TFloatOrBFloat16, INT64],
+    frame_length_const: INT64,
+    signal_rank: INT64,
+    onesided: int,
+) -> TFloatOrBFloat16:
+    window = op.CastLike(window, signal)
+    result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
+    # ONNX STFT emits [batch, frames, freq_bins, 2]; transpose to torch.stft's
+    # [batch, freq_bins, frames, 2] layout
+    result = op.Transpose(result, perm=[0, 2, 1, 3])
+    # Remove the batch dimension, if one was added
+    if signal_rank == 1:
+        result = op.Squeeze(result, op.Constant(value_ints=[0]))
+    return result
+
+
+@torch_op("aten::stft", trace_only=True)
+def aten_stft(
+    self: TFloatOrBFloat16,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[TFloatOrBFloat16] = None,
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+) -> TFloatOrBFloat16:
+    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
+
+    # NOTE: Regardless of the value of return_complex, we always return a real representation.
+    del return_complex
+
+    # Get STFT sizes
+    if hop_length is None:
+        # Computing the default with op.Div causes a core dump, so use Python ints:
+        # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
+        hop_length = n_fft // 4
+    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
+    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+
+    # Pre-process input if needed
+    self, signal_rank = _add_batch_dimension(self)
+
+    # Get the window and make sure it's the same size as `win_length` or `n_fft`
+    if window is not None and window.shape[0] is not None:
+        window = _center_window_around_zeros_if_needed(window, n_fft)
+    elif window is None:
+        if win_length is not None:
+            window = _create_window_from_win_length(win_length, n_fft)
+        else:
+            window = _create_window_from_n_fft(n_fft)
+
+    if onesided is None or onesided:
+        onesided = 1
+    else:
+        onesided = 0
+    # _aten_stft_onnx also removes the batch dimension added above, if needed
+    result = _aten_stft_onnx(
+        self, frame_step_const, window, frame_length_const, signal_rank, onesided
+    )
+
+    # Normalize, if needed
+    if normalized:
+        result = _normalize_fft_result(self, result, n_fft)
+
+    return result
+
+
 @torch_op(
     (
         "aten::sub.Tensor",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 8 additions & 0 deletions
@@ -1760,6 +1760,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value),
     TorchLibOpInfo("slice", core_ops.aten_slice),
     TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
+    TorchLibOpInfo(
+        "ops.aten.stft",  # Custom from extra_opinfo
+        core_ops.aten_stft,
+        tolerance={torch.float32: (3.7e-5, 1.8e-4)},
+    ).xfail(
+        dtypes=(torch.float16,),
+        reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
+    ),
     TorchLibOpInfo(
         "sum",
         core_ops.aten_sum_dim_IntList,
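For readers unfamiliar with the harness: this entry runs `aten_stft` against the custom `ops.aten.stft` OpInfo with a loosened float32 comparison tolerance, and marks float16 as an expected failure because MKL's FFT cannot produce the reference result. Assuming the tolerance tuple is `(rtol, atol)` (an assumption for this sketch, not confirmed by the diff), the comparison amounts to roughly:

```python
import torch

def check_stft_result(actual: torch.Tensor, expected: torch.Tensor) -> None:
    # Hypothetical stand-in for the harness's comparison step; the (rtol, atol)
    # ordering of the tolerance tuple is assumed for illustration.
    torch.testing.assert_close(actual, expected, rtol=3.7e-5, atol=1.8e-4)
```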
