Skip to content

Commit 44786f4

Browse files
committed
Revert "[torchlib] Unregister stft, var, var_mean, std, std_mean" (#1867)
This reverts commit 1eef633.
1 parent 5be9d3b commit 44786f4

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8128,6 +8128,139 @@ def aten_std_mean_correction(
81288128
return op.Sqrt(var), mean
81298129

81308130

8131+
@torch_op("aten::stft", private=True)
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
    """Ensure the signal has a leading batch dimension.

    ONNX's STFT expects a batched input, so a rank-1 signal is unsqueezed to
    rank 2. Returns the (possibly reshaped) signal together with the ORIGINAL
    rank, so the caller can squeeze the batch dimension back out afterwards.
    """
    signal_rank = Rank(self)
    if signal_rank == 1:
        # Add a batch dimension
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    # op.Identity so both branches of the traced graph yield a fresh output.
    return op.Identity(self), signal_rank
8138+
8139+
8140+
@torch_op("aten::stft", private=True)
def _center_window_around_zeros_if_needed(
    window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
    """Zero-pad a user-supplied window to length ``n_fft``, centered.

    If the window is shorter than ``n_fft``, zeros are concatenated on both
    sides (left gets the floor of the difference halved, right gets the
    remainder) so the result has exactly ``n_fft`` samples, as required by
    ONNX's STFT. A window already of length ``n_fft`` is returned unchanged.
    """
    # first dimension
    n_win = op.Shape(window, start=0, end=1)
    # Center window around zeros if needed (required by ONNX's STFT)
    if n_win < n_fft:
        left = (n_fft - n_win) / 2

        right = n_fft - left - n_win
        left = op.Reshape(left, op.Constant(value_ints=[1]))
        right = op.Reshape(right, op.Constant(value_ints=[1]))

        # Expand a scalar 0 into the left/right padding runs, then cast the
        # padding to the window's dtype before concatenating.
        left_win = op.Expand(op.Constant(value_ints=[0]), left)
        right_win = op.Expand(op.Constant(value_ints=[0]), right)
        right_win = op.CastLike(right_win, window)
        left_win = op.CastLike(left_win, window)
        window = op.Concat(left_win, window, right_win, axis=0)
    return window
8160+
8161+
8162+
@torch_op("aten::stft", private=True)
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
    """Build a centered rectangular window of ``win_length`` ones inside ``n_fft``.

    When the caller provides ``win_length`` but no window tensor, the default
    is a run of ones of that length, zero-padded on both sides so the total
    length is ``n_fft`` (mirrors torch.stft's default rectangular window).
    """
    left = (n_fft - win_length) / 2

    right = n_fft - left - win_length
    left = op.Reshape(left, op.Constant(value_ints=[1]))
    right = op.Reshape(right, op.Constant(value_ints=[1]))
    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))

    # Zero padding on either side, ones in the middle.
    left_win = op.Expand(op.Constant(value_ints=[0]), left)
    right_win = op.Expand(op.Constant(value_ints=[0]), right)
    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
    return op.Concat(left_win, window_list, right_win, axis=0)
8175+
8176+
8177+
@torch_op("aten::stft", private=True)
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
    """Build an all-ones window of length ``n_fft``.

    Used when neither ``window`` nor ``win_length`` is provided: every sample
    of each frame is weighted equally (rectangular window).
    """
    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
    return window
8182+
8183+
8184+
@torch_op("aten::stft", private=True)
def _normalize_fft_result(
    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
    """Scale the STFT result by ``1/sqrt(n_fft)`` (``normalized=True`` path).

    ``signal`` is used only as a dtype reference for the CastLike, so the
    division happens in the signal's floating-point type.
    """
    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
    result = result / sqrt_nfft
    return result
8192+
8193+
8194+
@torch_op("aten::stft", private=True)
def _aten_stft_onnx(
    signal: TFloatOrBFloat16,
    frame_step_const: INT64,
    window: Union[TFloatOrBFloat16, INT64],
    frame_length_const: INT64,
    signal_rank: INT64,
    onesided: int,
) -> TFloatOrBFloat16:
    """Invoke ONNX STFT and massage the output into torch.stft's layout.

    ``window`` may be an INT64 tensor (when synthesized from n_fft/win_length),
    so it is cast to the signal's dtype first. The Transpose swaps the frame
    and frequency axes — presumably op.STFT emits [batch, frames, freq, 2]
    while torch.stft expects [batch, freq, frames, 2]; confirm against the
    ONNX STFT spec. ``signal_rank`` is the rank of the ORIGINAL (pre-batched)
    input: if it was 1, the batch dimension added earlier is squeezed back out.
    """
    window = op.CastLike(window, signal)
    result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
    result = op.Transpose(result, perm=[0, 2, 1, 3])
    # Remove batch dimension, if needed
    if signal_rank == 1:
        result = op.Squeeze(result, op.Constant(value_ints=[0]))
    return result
8210+
8211+
8212+
@torch_op("aten::stft", trace_only=True)
def aten_stft(
    self: TFloatOrBFloat16,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[TFloatOrBFloat16] = None,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> TFloatOrBFloat16:
    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""

    # NOTE: Regardless of the value of return_complex, a real representation
    # is always returned (last dimension holds the real/imaginary pair).
    del return_complex

    # Default hop length per torch.stft is n_fft // 4. Computed in plain
    # Python: building it with op.Div on constants crashed (core dump) —
    # see the original implementation's note.
    if hop_length is None:
        hop_length = n_fft // 4
    frame_step = op.Reshape(hop_length, op.Constant(value_ints=[1]))
    frame_length = op.Reshape(n_fft, op.Constant(value_ints=[1]))

    # ONNX STFT needs a batched signal; keep the original rank so the helper
    # can squeeze the batch dimension back out at the end.
    self, signal_rank = _add_batch_dimension(self)

    # Ensure the window is exactly n_fft samples long:
    #   - user-supplied window with a known static length -> pad with zeros;
    #   - no window but a win_length -> centered rectangular window;
    #   - neither -> all-ones window of length n_fft.
    if window is not None and window.shape[0] is not None:
        window = _center_window_around_zeros_if_needed(window, n_fft)
    elif window is None:
        window = (
            _create_window_from_win_length(win_length, n_fft)
            if win_length is not None
            else _create_window_from_n_fft(n_fft)
        )

    # torch.stft treats an unspecified onesided as True.
    onesided_flag = 1 if onesided is None or onesided else 0

    # remove batch dimension included
    result = _aten_stft_onnx(
        self, frame_step, window, frame_length, signal_rank, onesided_flag
    )

    if normalized:
        result = _normalize_fft_result(self, result, n_fft)

    return result
8262+
8263+
81318264
@torch_op(
81328265
(
81338266
"aten::sub.Tensor",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,14 @@ def _where_input_wrangler(
17601760
TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value),
17611761
TorchLibOpInfo("slice", core_ops.aten_slice),
17621762
TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
1763+
TorchLibOpInfo(
1764+
"ops.aten.stft", # Custom from extra_opinfo
1765+
core_ops.aten_stft,
1766+
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
1767+
).xfail(
1768+
dtypes=(torch.float16,),
1769+
reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
1770+
),
17631771
TorchLibOpInfo(
17641772
"sum",
17651773
core_ops.aten_sum_dim_IntList,

0 commit comments

Comments
 (0)