@@ -8118,6 +8118,139 @@ def aten_std_mean_correction(
81188118 return op .Sqrt (var ), mean
81198119
81208120
@torch_op("aten::stft", private=True)
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
    """Give a rank-1 signal a leading axis and report the rank it had on entry.

    Args:
        self: Input signal. Only the rank-1 case is modified.

    Returns:
        A pair ``(signal, rank)``: the signal (unsqueezed at axis 0 when the
        incoming rank was 1, otherwise passed through ``Identity``) and the
        value produced by ``Rank`` for the incoming tensor.
    """
    # NOTE(review): `Rank` is defined elsewhere; presumably it yields the
    # number of dimensions of its argument -- confirm against the helper.
    rank = Rank(self)
    if rank == 1:
        # Insert a new axis in front of the only existing one.
        batch_axis = op.Constant(value_ints=[0])
        self = op.Unsqueeze(self, batch_axis)
    batched = op.Identity(self)
    return batched, rank
8128+
8129+
@torch_op("aten::stft", private=True)
def _center_window_around_zeros_if_needed(
    window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
    """Zero-pad ``window`` on both sides when its first dimension is below ``n_fft``.

    Args:
        window: Window tensor; only its first dimension is inspected.
        n_fft: Length the padded window should reach.

    Returns:
        ``window`` unchanged when its first dimension is not smaller than
        ``n_fft``; otherwise zeros + window + zeros concatenated on axis 0,
        with the zeros cast to the window's dtype.
    """
    # Size of the first dimension, kept as a one-element shape tensor.
    window_size = op.Shape(window, start=0, end=1)
    # ONNX's STFT needs the window centered inside an n_fft-long frame.
    if window_size < n_fft:
        # Half of the missing length goes in front; whatever is left goes after.
        # NOTE(review): `/` on these integer values presumably lowers to an
        # integer ONNX Div -- confirm.
        pad_before = (n_fft - window_size) / 2
        pad_after = n_fft - pad_before - window_size

        # Expand takes its target shape as a 1-D tensor.
        pad_before = op.Reshape(pad_before, op.Constant(value_ints=[1]))
        pad_after = op.Reshape(pad_after, op.Constant(value_ints=[1]))

        zeros_before = op.CastLike(
            op.Expand(op.Constant(value_ints=[0]), pad_before), window
        )
        zeros_after = op.CastLike(
            op.Expand(op.Constant(value_ints=[0]), pad_after), window
        )
        window = op.Concat(zeros_before, window, zeros_after, axis=0)
    return window
8150+
8151+
@torch_op("aten::stft", private=True)
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
    """Build a window of ``win_length`` ones with zero padding on both sides.

    Args:
        win_length: Number of ones in the middle of the window.
        n_fft: Value used to size the padding: ``(n_fft - win_length) / 2``
            zeros in front and the remainder behind.

    Returns:
        The concatenation (axis 0) of leading zeros, ones, and trailing zeros,
        all built from integer constants.
    """
    # NOTE(review): `/` on these integer values presumably lowers to an
    # integer ONNX Div -- confirm.
    pad_before = (n_fft - win_length) / 2
    pad_after = n_fft - pad_before - win_length

    # Expand takes its target shape as a 1-D tensor, so reshape every length.
    pad_before = op.Reshape(pad_before, op.Constant(value_ints=[1]))
    pad_after = op.Reshape(pad_after, op.Constant(value_ints=[1]))
    ones_shape = op.Reshape(win_length, op.Constant(value_ints=[1]))

    zeros_before = op.Expand(op.Constant(value_ints=[0]), pad_before)
    zeros_after = op.Expand(op.Constant(value_ints=[0]), pad_after)
    ones = op.Expand(op.Constant(value_ints=[1]), ones_shape)
    return op.Concat(zeros_before, ones, zeros_after, axis=0)
8165+
8166+
@torch_op("aten::stft", private=True)
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
    """Build a window consisting of ``n_fft`` ones.

    Args:
        n_fft: Length of the window to create.

    Returns:
        The integer constant 1 expanded to a 1-D tensor of length ``n_fft``.
    """
    # Expand takes its target shape as a 1-D tensor.
    ones_shape = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    return op.Expand(op.Constant(value_ints=[1]), ones_shape)
8172+
8173+
@torch_op("aten::stft", private=True)
def _normalize_fft_result(
    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
    """Divide ``result`` by the square root of ``n_fft``.

    Args:
        signal: Used only as the dtype reference for casting ``n_fft``.
        result: Tensor to scale.
        n_fft: Value whose square root is the divisor.

    Returns:
        ``result / sqrt(n_fft)``.
    """
    # Cast n_fft to the signal's dtype before Sqrt, then divide.
    scale = op.Sqrt(
        op.CastLike(op.Reshape(n_fft, op.Constant(value_ints=[1])), signal)
    )
    normalized = result / scale
    return normalized
8182+
8183+
@torch_op("aten::stft", private=True)
def _aten_stft_onnx(
    signal: TFloatOrBFloat16,
    frame_step_const: INT64,
    window: Union[TFloatOrBFloat16, INT64],
    frame_length_const: INT64,
    signal_rank: INT64,
    onesided: int,
) -> TFloatOrBFloat16:
    """Run the ONNX STFT op and reorder its output.

    Args:
        signal: Input passed to ``op.STFT``.
        frame_step_const: Frame step passed to ``op.STFT``.
        window: Window tensor; cast to the signal's dtype before use.
        frame_length_const: Frame length passed to ``op.STFT``.
        signal_rank: When equal to 1, axis 0 of the result is squeezed away.
        onesided: Forwarded as the ``onesided`` attribute of ``op.STFT``.

    Returns:
        The STFT output with axes 1 and 2 swapped, minus axis 0 when
        ``signal_rank`` is 1.
    """
    # The window may arrive as an integer tensor; match the signal's dtype.
    typed_window = op.CastLike(window, signal)
    frames = op.STFT(
        signal, frame_step_const, typed_window, frame_length_const, onesided=onesided
    )
    # NOTE(review): swapping axes 1 and 2 presumably converts ONNX's layout to
    # the one torch.stft produces -- confirm against both specs.
    result = op.Transpose(frames, perm=[0, 2, 1, 3])
    # Drop the leading axis when the signal rank is 1.
    if signal_rank == 1:
        result = op.Squeeze(result, op.Constant(value_ints=[0]))
    return result
8200+
8201+
@torch_op("aten::stft", trace_only=True)
def aten_stft(
    self: TFloatOrBFloat16,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[TFloatOrBFloat16] = None,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> TFloatOrBFloat16:
    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor

    Args:
        self: Input signal; a batch axis is added for the rank-1 case.
        n_fft: Frame length; also the default source for hop and window sizes.
        hop_length: Frame step. Defaults to ``n_fft // 4`` when ``None``.
        win_length: Used only when ``window`` is ``None`` to size a window of ones.
        window: Optional window tensor. It is zero-padded towards ``n_fft``
            when its first dimension is statically known.
        normalized: When true, the result is divided by ``sqrt(n_fft)``.
        onesided: ``None`` or a truthy value selects the one-sided output.
        return_complex: Ignored; the result is always a real representation.

    Returns:
        The STFT result as produced by ``_aten_stft_onnx``, optionally normalized.
    """

    # NOTE: regardless of the value of return_complex, a real representation
    # is always returned, so the argument is discarded.
    del return_complex

    # STFT sizes. The default hop is computed in Python; the original author
    # noted that an op.Div-based variant caused a core dump.
    frame_step = n_fft // 4 if hop_length is None else hop_length
    frame_step_const = op.Reshape(frame_step, op.Constant(value_ints=[1]))
    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))

    # Add a batch axis to the signal if it needs one.
    signal, signal_rank = _add_batch_dimension(self)

    # Pick or synthesize the window.
    if window is None:
        if win_length is None:
            window = _create_window_from_n_fft(n_fft)
        else:
            window = _create_window_from_win_length(win_length, n_fft)
    elif window.shape[0] is not None:
        # Only pad a caller-supplied window when its length is known.
        window = _center_window_around_zeros_if_needed(window, n_fft)

    # Unspecified means one-sided.
    onesided_attr = 1 if (onesided is None or onesided) else 0

    # The helper also strips the batch axis again when the signal rank was 1.
    result = _aten_stft_onnx(
        signal, frame_step_const, window, frame_length_const, signal_rank, onesided_attr
    )

    # Normalize, if needed.
    if normalized:
        result = _normalize_fft_result(signal, result, n_fft)

    return result
8252+
8253+
81218254@torch_op (
81228255 (
81238256 "aten::sub.Tensor" ,
0 commit comments