@@ -8128,6 +8128,139 @@ def aten_std_mean_correction(
81288128 return op .Sqrt (var ), mean
81298129
81308130
@torch_op("aten::stft", private=True)
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
    """Give a rank-1 signal a leading axis and report the rank it had on entry.

    Returns:
        A pair ``(signal, rank)``: the signal (unsqueezed at axis 0 when its
        rank was 1, otherwise unchanged) and the rank measured before any
        unsqueeze, so the caller can tell whether an axis was added.
    """
    rank_on_entry = Rank(self)
    if rank_on_entry == 1:
        # Insert a new axis 0 in front of the single existing dimension.
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    return op.Identity(self), rank_on_entry
8138+
8139+
@torch_op("aten::stft", private=True)
def _center_window_around_zeros_if_needed(
    window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
    """Surround `window` with zeros when its first dimension is smaller than `n_fft`.

    The leading pad has ``(n_fft - n_win) / 2`` elements and the trailing pad
    takes whatever remains, so the concatenated result has `n_fft` elements
    along axis 0. A window that is not shorter than `n_fft` is returned as is.
    """
    # Size of the window's first dimension.
    window_size = op.Shape(window, start=0, end=1)
    # ONNX's STFT requires the window to be centered between zeros.
    if window_size < n_fft:
        pad_before = (n_fft - window_size) / 2
        pad_after = n_fft - pad_before - window_size

        # Expand takes its target shape as a 1-D tensor.
        pad_before = op.Reshape(pad_before, op.Constant(value_ints=[1]))
        pad_after = op.Reshape(pad_after, op.Constant(value_ints=[1]))

        # Build integer zeros of each pad size, then match the window's dtype.
        zeros_before = op.CastLike(
            op.Expand(op.Constant(value_ints=[0]), pad_before), window
        )
        zeros_after = op.CastLike(
            op.Expand(op.Constant(value_ints=[0]), pad_after), window
        )
        window = op.Concat(zeros_before, window, zeros_after, axis=0)
    return window
8160+
8161+
@torch_op("aten::stft", private=True)
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
    """Build a window of `win_length` ones with zeros on both sides, `n_fft` long in total.

    The leading zero run has ``(n_fft - win_length) / 2`` elements; the trailing
    run takes the remainder. The pieces are expanded from integer constants.
    """
    pad_before = (n_fft - win_length) / 2
    pad_after = n_fft - pad_before - win_length

    # Expand takes its target shape as a 1-D tensor, so reshape each count.
    pad_before_1d = op.Reshape(pad_before, op.Constant(value_ints=[1]))
    pad_after_1d = op.Reshape(pad_after, op.Constant(value_ints=[1]))
    ones_count_1d = op.Reshape(win_length, op.Constant(value_ints=[1]))

    zeros_before = op.Expand(op.Constant(value_ints=[0]), pad_before_1d)
    zeros_after = op.Expand(op.Constant(value_ints=[0]), pad_after_1d)
    ones_middle = op.Expand(op.Constant(value_ints=[1]), ones_count_1d)
    return op.Concat(zeros_before, ones_middle, zeros_after, axis=0)
8175+
8176+
@torch_op("aten::stft", private=True)
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
    """Return a 1-D window of `n_fft` ones, expanded from an integer constant."""
    # Expand takes its target shape as a 1-D tensor.
    window_shape = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    return op.Expand(op.Constant(value_ints=[1]), window_shape)
8182+
8183+
@torch_op("aten::stft", private=True)
def _normalize_fft_result(
    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
    """Divide `result` by ``sqrt(n_fft)``.

    `signal` is used only as the dtype reference for the cast of `n_fft`
    before the square root is taken.
    """
    n_fft_1d = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    n_fft_like_signal = op.CastLike(n_fft_1d, signal)
    scale = op.Sqrt(n_fft_like_signal)
    normalized = result / scale
    return normalized
8192+
8193+
@torch_op("aten::stft", private=True)
def _aten_stft_onnx(
    signal: TFloatOrBFloat16,
    frame_step_const: INT64,
    window: Union[TFloatOrBFloat16, INT64],
    frame_length_const: INT64,
    signal_rank: INT64,
    onesided: int,
) -> TFloatOrBFloat16:
    """Run the ONNX STFT op and rearrange its output.

    The window is cast to the signal's dtype, axes 1 and 2 of the STFT output
    are swapped, and axis 0 is squeezed away when `signal_rank` is 1.
    """
    # The window may arrive as an integer tensor; match the signal's dtype.
    window_like_signal = op.CastLike(window, signal)
    spectrum = op.STFT(
        signal, frame_step_const, window_like_signal, frame_length_const, onesided=onesided
    )
    spectrum = op.Transpose(spectrum, perm=[0, 2, 1, 3])
    # A signal that was rank 1 on entry should not keep a leading axis.
    if signal_rank == 1:
        spectrum = op.Squeeze(spectrum, op.Constant(value_ints=[0]))
    return spectrum
8210+
8211+
@torch_op("aten::stft", trace_only=True)
def aten_stft(
    self: TFloatOrBFloat16,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[TFloatOrBFloat16] = None,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> TFloatOrBFloat16:
    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""

    # NOTE: `return_complex` is ignored; a real representation is always returned.
    del return_complex

    # Frame step and frame length, each reshaped to a 1-element tensor.
    if hop_length is None:
        hop_length = n_fft // 4
    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))

    # A rank-1 signal gets a batch axis; the original rank is kept for later.
    self, signal_rank = _add_batch_dimension(self)

    # Pick the window: synthesize one when none was given, otherwise pad the
    # supplied one up to `n_fft`. A supplied window whose first dimension is
    # not known is passed through untouched.
    if window is None:
        if win_length is None:
            window = _create_window_from_n_fft(n_fft)
        else:
            window = _create_window_from_win_length(win_length, n_fft)
    elif window.shape[0] is not None:
        window = _center_window_around_zeros_if_needed(window, n_fft)

    # An unspecified `onesided` is treated as enabled.
    onesided_flag = 1 if onesided is None or onesided else 0

    # The helper also squeezes the batch axis back out when the signal was rank 1.
    result = _aten_stft_onnx(
        self, frame_step_const, window, frame_length_const, signal_rank, onesided_flag
    )

    # Optional normalization step.
    if normalized:
        result = _normalize_fft_result(self, result, n_fft)

    return result
8262+
8263+
81318264@torch_op (
81328265 (
81338266 "aten::sub.Tensor" ,
0 commit comments