Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ class MelScale(torch.nn.Module):
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to
generate the filter banks.
"""
__constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

__constants__ = ["n_mels", "sample_rate", "f_min", "f_max", "n_stft"]

def __init__(
self,
Expand All @@ -391,13 +392,16 @@ def __init__(
self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min
self.n_stft = n_stft
self.norm = norm
self.mel_scale = mel_scale

if f_min > self.f_max:
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))

fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
fb = F.melscale_fbanks(
self.n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale
)
self.register_buffer("fb", fb)

def forward(self, specgram: Tensor) -> Tensor:
Expand Down Expand Up @@ -464,10 +468,13 @@ def __init__(
driver: str = "gels",
) -> None:
super(InverseMelScale, self).__init__()
self.n_stft = n_stft
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max or float(sample_rate // 2)
self.f_min = f_min
self.norm = norm
self.mel_scale = mel_scale
self.driver = driver

if f_min > self.f_max:
Expand All @@ -476,7 +483,9 @@ def __init__(
if driver not in ["gels", "gelsy", "gelsd", "gelss"]:
raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.')

fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
fb = F.melscale_fbanks(
self.n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale
)
self.register_buffer("fb", fb)

def forward(self, melspec: Tensor) -> Tensor:
Expand Down
Loading