From 5ef4fb82c5546088257edd92321b148612acee2e Mon Sep 17 00:00:00 2001 From: sammlapp Date: Mon, 3 Nov 2025 20:16:17 -0500 Subject: [PATCH 1/2] retain arguments from init as properties in MelScale and InverseMelScale MelScale should retain n_stft as an attribute Fixes #4122 I observed inconsistency between the MelScale and InverseMelScale in which arguments were retained as attributes. Perhaps there was some reason for this, but I think it probably makes sense to save all arguments as attributes for both classes. This allows serialization of the object for re-initialization with the same parameters. --- src/torchaudio/transforms/_transforms.py | 32 +++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 08d2dcef11..a70306ddee 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -59,6 +59,7 @@ class Spectrogram(torch.nn.Module): >>> spectrogram = transform(waveform) """ + __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"] def __init__( @@ -156,6 +157,7 @@ class InverseSpectrogram(torch.nn.Module): >>> transform = transforms.InverseSpectrogram(n_fft=512) >>> waveform = transform(spectrogram, length) """ + __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"] def __init__( @@ -242,6 +244,7 @@ class GriffinLim(torch.nn.Module): >>> transform = transforms.GriffinLim(n_fft=512) >>> waveform = transform(spectrogram) """ + __constants__ = ["n_fft", "n_iter", "win_length", "hop_length", "power", "length", "momentum", "rand_init"] def __init__( @@ -319,6 +322,7 @@ class AmplitudeToDB(torch.nn.Module): >>> transform = transforms.AmplitudeToDB(stype="amplitude", top_db=80) >>> waveform_db = transform(waveform) """ + __constants__ = ["multiplier", "amin", "ref_value", "db_multiplier"] def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None: @@ -374,7 +378,8 @@ class MelScale(torch.nn.Module): :py:func:`torchaudio.functional.melscale_fbanks` - The function used to generate the filter banks. """ - __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"] + + __constants__ = ["n_mels", "sample_rate", "f_min", "f_max", "n_stft"] def __init__( self, @@ -391,13 +396,16 @@ def __init__( self.sample_rate = sample_rate self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_min = f_min + self.n_stft = n_stft self.norm = norm self.mel_scale = mel_scale if f_min > self.f_max: raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) - fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale) + fb = F.melscale_fbanks( + self.n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale + ) self.register_buffer("fb", fb) def forward(self, specgram: Tensor) -> Tensor: @@ -444,6 +452,7 @@ class InverseMelScale(torch.nn.Module): >>> inverse_melscale_transform = transforms.InverseMelScale(n_stft=1024 // 2 + 1) >>> spectrogram = inverse_melscale_transform(mel_spectrogram) """ + __constants__ = [ "n_stft", "n_mels", @@ -464,10 +473,13 @@ def __init__( driver: str = "gels", ) -> None: super(InverseMelScale, self).__init__() + self.n_stft = n_stft self.n_mels = n_mels self.sample_rate = sample_rate self.f_max = f_max or float(sample_rate // 2) self.f_min = f_min + self.norm = norm + self.mel_scale = mel_scale self.driver = driver if f_min > self.f_max: @@ -476,7 +488,9 @@ def __init__( if driver not in ["gels", "gelsy", "gelsd", "gelss"]: raise ValueError(f'driver must be one of ["gels", "gelsy", "gelsd", "gelss"]. Found {driver}.') - fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale) + fb = F.melscale_fbanks( + self.n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale + ) self.register_buffer("fb", fb) def forward(self, melspec: Tensor) -> Tensor: @@ -552,6 +566,7 @@ class MelSpectrogram(torch.nn.Module): :py:func:`torchaudio.functional.melscale_fbanks` - The function used to generate the filter banks. """ + __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"] def __init__( @@ -658,6 +673,7 @@ class MFCC(torch.nn.Module): :py:func:`torchaudio.functional.melscale_fbanks` - The function used to generate the filter banks. """ + __constants__ = ["sample_rate", "n_mfcc", "dct_type", "top_db", "log_mels"] def __init__( @@ -748,6 +764,7 @@ class LFCC(torch.nn.Module): :py:func:`torchaudio.functional.linear_fbanks` - The function used to generate the filter banks. """ + __constants__ = ["sample_rate", "n_filter", "n_lfcc", "dct_type", "top_db", "log_lf"] def __init__( @@ -841,6 +858,7 @@ class MuLawEncoding(torch.nn.Module): >>> mulawtrans = transform(waveform) """ + __constants__ = ["quantization_channels"] def __init__(self, quantization_channels: int = 256) -> None: @@ -879,6 +897,7 @@ class MuLawDecoding(torch.nn.Module): >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512) >>> mulawtrans = transform(waveform) """ + __constants__ = ["quantization_channels"] def __init__(self, quantization_channels: int = 256) -> None: @@ -993,6 +1012,7 @@ class ComputeDeltas(torch.nn.Module): win_length (int, optional): The window length used for computing delta. (Default: ``5``) mode (str, optional): Mode parameter passed to padding. (Default: ``"replicate"``) """ + __constants__ = ["win_length"] def __init__(self, win_length: int = 5, mode: str = "replicate") -> None: @@ -1043,6 +1063,7 @@ class TimeStretch(torch.nn.Module): :width: 600 :alt: The visualization of stretched spectrograms. """ + __constants__ = ["fixed_rate"] def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_rate: Optional[float] = None) -> None: @@ -1176,6 +1197,7 @@ class _AxisMasking(torch.nn.Module): This option is applicable only when the dimension of the input tensor is >= 3. p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0) """ + __constants__ = ["mask_param", "axis", "iid_masks", "p"] def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None: @@ -1291,6 +1313,7 @@ class SpecAugment(torch.nn.Module): zero_masking (bool, optional): If ``True``, use 0 as the mask value, else use mean of the input tensor. (Default: ``False``) """ + __constants__ = [ "n_time_masks", "time_mask_param", @@ -1366,6 +1389,7 @@ class Loudness(torch.nn.Module): Reference: - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en """ + __constants__ = ["sample_rate"] def __init__(self, sample_rate: int): @@ -1635,6 +1659,7 @@ class SpectralCentroid(torch.nn.Module): >>> transform = transforms.SpectralCentroid(sample_rate) >>> spectral_centroid = transform(waveform) # (channel, time) """ + __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad"] def __init__( @@ -1694,6 +1719,7 @@ class PitchShift(LazyModuleMixin, torch.nn.Module): >>> transform = transforms.PitchShift(sample_rate, 4) >>> waveform_shift = transform(waveform) # (channel, time) """ + __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"] kernel: UninitializedParameter From fae5180bc3b4869579bf2095b04f4224e80beaad Mon Sep 17 00:00:00 2001 From: sammlapp Date: Mon, 3 Nov 2025 20:24:25 -0500 Subject: [PATCH 2/2] Undo unintentional formatting --- src/torchaudio/transforms/_transforms.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index a70306ddee..51f54eeaf8 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -59,7 +59,6 @@ class Spectrogram(torch.nn.Module): >>> spectrogram = transform(waveform) """ - __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"] def __init__( @@ -157,7 +156,6 @@ class InverseSpectrogram(torch.nn.Module): >>> transform = transforms.InverseSpectrogram(n_fft=512) >>> waveform = transform(spectrogram, length) """ - __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"] def __init__( @@ -244,7 +242,6 @@ class GriffinLim(torch.nn.Module): >>> transform = transforms.GriffinLim(n_fft=512) >>> waveform = transform(spectrogram) """ - __constants__ = ["n_fft", "n_iter", "win_length", "hop_length", "power", "length", "momentum", "rand_init"] def __init__( @@ -322,7 +319,6 @@ class AmplitudeToDB(torch.nn.Module): >>> transform = transforms.AmplitudeToDB(stype="amplitude", top_db=80) >>> waveform_db = transform(waveform) """ - __constants__ = ["multiplier", "amin", "ref_value", "db_multiplier"] def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None: @@ -452,7 +448,6 @@ class InverseMelScale(torch.nn.Module): >>> inverse_melscale_transform = transforms.InverseMelScale(n_stft=1024 // 2 + 1) >>> spectrogram = inverse_melscale_transform(mel_spectrogram) """ - __constants__ = [ "n_stft", "n_mels", @@ -566,7 +561,6 @@ class MelSpectrogram(torch.nn.Module): :py:func:`torchaudio.functional.melscale_fbanks` - The function used to generate the filter banks. """ - __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"] def __init__( @@ -673,7 +667,6 @@ class MFCC(torch.nn.Module): :py:func:`torchaudio.functional.melscale_fbanks` - The function used to generate the filter banks. """ - __constants__ = ["sample_rate", "n_mfcc", "dct_type", "top_db", "log_mels"] def __init__( @@ -764,7 +757,6 @@ class LFCC(torch.nn.Module): :py:func:`torchaudio.functional.linear_fbanks` - The function used to generate the filter banks. """ - __constants__ = ["sample_rate", "n_filter", "n_lfcc", "dct_type", "top_db", "log_lf"] def __init__( @@ -858,7 +850,6 @@ class MuLawEncoding(torch.nn.Module): >>> mulawtrans = transform(waveform) """ - __constants__ = ["quantization_channels"] def __init__(self, quantization_channels: int = 256) -> None: @@ -897,7 +888,6 @@ class MuLawDecoding(torch.nn.Module): >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512) >>> mulawtrans = transform(waveform) """ - __constants__ = ["quantization_channels"] def __init__(self, quantization_channels: int = 256) -> None: @@ -1012,7 +1002,6 @@ class ComputeDeltas(torch.nn.Module): win_length (int, optional): The window length used for computing delta. (Default: ``5``) mode (str, optional): Mode parameter passed to padding. (Default: ``"replicate"``) """ - __constants__ = ["win_length"] def __init__(self, win_length: int = 5, mode: str = "replicate") -> None: @@ -1063,7 +1052,6 @@ class TimeStretch(torch.nn.Module): :width: 600 :alt: The visualization of stretched spectrograms. """ - __constants__ = ["fixed_rate"] def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_rate: Optional[float] = None) -> None: @@ -1197,7 +1185,6 @@ class _AxisMasking(torch.nn.Module): This option is applicable only when the dimension of the input tensor is >= 3. p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0) """ - __constants__ = ["mask_param", "axis", "iid_masks", "p"] def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None: @@ -1313,7 +1300,6 @@ class SpecAugment(torch.nn.Module): zero_masking (bool, optional): If ``True``, use 0 as the mask value, else use mean of the input tensor. (Default: ``False``) """ - __constants__ = [ "n_time_masks", "time_mask_param", @@ -1389,7 +1375,6 @@ class Loudness(torch.nn.Module): Reference: - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en """ - __constants__ = ["sample_rate"] def __init__(self, sample_rate: int): @@ -1659,7 +1644,6 @@ class SpectralCentroid(torch.nn.Module): >>> transform = transforms.SpectralCentroid(sample_rate) >>> spectral_centroid = transform(waveform) # (channel, time) """ - __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad"] def __init__( @@ -1719,7 +1703,6 @@ class PitchShift(LazyModuleMixin, torch.nn.Module): >>> transform = transforms.PitchShift(sample_rate, 4) >>> waveform_shift = transform(waveform) # (channel, time) """ - __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"] kernel: UninitializedParameter