Skip to content

Commit 29aaae3

Browse files
alancuckinv-kkudrynski
authored andcommitted
[Jasper/PyT] Update torch.stft for PyTorch 2.0
1 parent 9de48bc commit 29aaae3

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

PyTorch/SpeechRecognition/Jasper/common/features.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,13 @@ def get_seq_len(self, seq_len):
244244
return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
245245
dtype=torch.int)
246246

247-
# do stft
248247
# TORCHSCRIPT: center removed due to bug
249248
def stft(self, x):
250-
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
249+
spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
251250
win_length=self.win_length,
252-
window=self.window.to(dtype=torch.float))
251+
window=self.window.to(dtype=torch.float),
252+
return_complex=True)
253+
return torch.view_as_real(spec)
253254

254255
@torch.no_grad()
255256
def calculate_features(self, x, seq_len):

PyTorch/SpeechRecognition/QuartzNet/common/features.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,13 @@ def get_seq_len(self, seq_len):
248248
return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
249249
dtype=torch.int)
250250

251-
# do stft
252251
# TORCHSCRIPT: center removed due to bug
253252
def stft(self, x):
254-
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
253+
spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
255254
win_length=self.win_length,
256-
window=self.window.to(dtype=torch.float))
255+
window=self.window.to(dtype=torch.float),
256+
return_complex=True)
257+
return torch.view_as_real(spec)
257258

258259
@torch.no_grad()
259260
def calculate_features(self, x, seq_len):

PyTorch/SpeechRecognition/wav2vec2/common/features.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,13 @@ def get_seq_len(self, seq_len):
261261
return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
262262
dtype=torch.int)
263263

264-
# do stft
265264
# TORCHSCRIPT: center removed due to bug
266265
def stft(self, x):
267-
return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
266+
spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
268267
win_length=self.win_length,
269-
window=self.window.to(dtype=torch.float))
268+
window=self.window.to(dtype=torch.float),
269+
return_complex=True)
270+
return torch.view_as_real(spec)
270271

271272
@torch.no_grad()
272273
def calculate_features(self, x, x_lens):

0 commit comments

Comments
 (0)