diff --git a/dac/model/base.py b/dac/model/base.py index 546b3cb..34510fa 100644 --- a/dac/model/base.py +++ b/dac/model/base.py @@ -215,6 +215,9 @@ def compress( codes = torch.cat(codes, dim=-1) + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + dac_file = DACFile( codes=codes, chunk_length=chunk_length, @@ -226,9 +229,6 @@ def compress( dac_version=SUPPORTED_VERSIONS[-1], ) - if n_quantizers is not None: - codes = codes[:, :n_quantizers, :] - self.padding = original_padding return dac_file