diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 08d2dcef11..51f54eeaf8 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -374,7 +374,8 @@ class MelScale(torch.nn.Module): :py:func:`torchaudio.functional.melscale_fbanks` - The function used to generate the filter banks. """ - __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"] + + __constants__ = ["n_mels", "sample_rate", "f_min", "f_max", "n_stft"] def __init__( self, @@ -391,13 +392,16 @@ def __init__( self.sample_rate = sample_rate self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min + self.n_stft = n_stft self.norm = norm self.mel_scale = mel_scale if f_min > self.f_max: raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) - fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) + fb = F.melscale_fbanks( + self.n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale + ) self.register_buffer("fb", fb) def forward(self, specgram: Tensor) -> Tensor: @@ -464,10 +468,13 @@ def __init__( driver: str = "gels", ) -> None: super(InverseMelScale, self).__init__() + self.n_stft = n_stft self.n_mels = n_mels self.sample_rate = sample_rate self.f_max = f_max or float(sample_rate // 2) self.f_min = f_min + self.norm = norm + self.mel_scale = mel_scale self.driver = driver if f_min > self.f_max: @@ -476,7 +483,9 @@ def __init__( if driver not in ["gels", "gelsy", "gelsd", "gelss"]: raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.') - fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale) + fb = F.melscale_fbanks( + self.n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale + ) self.register_buffer("fb", fb) def forward(self, melspec: Tensor) -> Tensor: