Skip to content

Commit a5a0ccf

Browse files
authored
[core] AutoencoderMixin to abstract common methods (#12473)
* up * correct wording. * up * up * up
1 parent dd07b19 commit a5a0ccf

19 files changed

+74
-357
lines changed

src/diffusers/models/autoencoders/autoencoder_asym_kl.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from ...utils.accelerate_utils import apply_forward_hook
2121
from ..modeling_outputs import AutoencoderKLOutput
2222
from ..modeling_utils import ModelMixin
23-
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
23+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
2424

2525

26-
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
26+
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
2727
r"""
2828
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
2929
KL loss for encoding images into latents and decoding latent representations into images.
@@ -107,9 +107,6 @@ def __init__(
107107
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
108108
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
109109

110-
self.use_slicing = False
111-
self.use_tiling = False
112-
113110
self.register_to_config(block_out_channels=up_block_out_channels)
114111
self.register_to_config(force_upcast=False)
115112

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ..modeling_utils import ModelMixin
2828
from ..normalization import RMSNorm, get_normalization
2929
from ..transformers.sana_transformer import GLUMBConv
30-
from .vae import DecoderOutput, EncoderOutput
30+
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
3131

3232

3333
class ResBlock(nn.Module):
@@ -378,7 +378,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
378378
return hidden_states
379379

380380

381-
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
381+
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
382382
r"""
383383
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
384384
[SANA](https://huggingface.co/papers/2410.10629).
@@ -536,27 +536,6 @@ def enable_tiling(
536536
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
537537
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
538538

539-
def disable_tiling(self) -> None:
540-
r"""
541-
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
542-
decoding in one step.
543-
"""
544-
self.use_tiling = False
545-
546-
def enable_slicing(self) -> None:
547-
r"""
548-
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
549-
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
550-
"""
551-
self.use_slicing = True
552-
553-
def disable_slicing(self) -> None:
554-
r"""
555-
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
556-
decoding in one step.
557-
"""
558-
self.use_slicing = False
559-
560539
def _encode(self, x: torch.Tensor) -> torch.Tensor:
561540
batch_size, num_channels, height, width = x.shape
562541

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
)
3333
from ..modeling_outputs import AutoencoderKLOutput
3434
from ..modeling_utils import ModelMixin
35-
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
35+
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
3636

3737

38-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
38+
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
3939
r"""
4040
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
4141
@@ -138,35 +138,6 @@ def __init__(
138138
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
139139
self.tile_overlap_factor = 0.25
140140

141-
def enable_tiling(self, use_tiling: bool = True):
142-
r"""
143-
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
144-
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
145-
processing larger images.
146-
"""
147-
self.use_tiling = use_tiling
148-
149-
def disable_tiling(self):
150-
r"""
151-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
152-
decoding in one step.
153-
"""
154-
self.enable_tiling(False)
155-
156-
def enable_slicing(self):
157-
r"""
158-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
159-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
160-
"""
161-
self.use_slicing = True
162-
163-
def disable_slicing(self):
164-
r"""
165-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
166-
decoding in one step.
167-
"""
168-
self.use_slicing = False
169-
170141
@property
171142
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
172143
def attn_processors(self) -> Dict[str, AttentionProcessor]:

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..modeling_utils import ModelMixin
2929
from ..resnet import ResnetBlock2D
3030
from ..upsampling import Upsample2D
31+
from .vae import AutoencoderMixin
3132

3233

3334
class AllegroTemporalConvLayer(nn.Module):
@@ -673,7 +674,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
673674
return sample
674675

675676

676-
class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
677+
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
677678
r"""
678679
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
679680
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -795,35 +796,6 @@ def __init__(
795796
sample_size - self.tile_overlap_w,
796797
)
797798

798-
def enable_tiling(self) -> None:
799-
r"""
800-
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
801-
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
802-
processing larger images.
803-
"""
804-
self.use_tiling = True
805-
806-
def disable_tiling(self) -> None:
807-
r"""
808-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
809-
decoding in one step.
810-
"""
811-
self.use_tiling = False
812-
813-
def enable_slicing(self) -> None:
814-
r"""
815-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
816-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
817-
"""
818-
self.use_slicing = True
819-
820-
def disable_slicing(self) -> None:
821-
r"""
822-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
823-
decoding in one step.
824-
"""
825-
self.use_slicing = False
826-
827799
def _encode(self, x: torch.Tensor) -> torch.Tensor:
828800
# TODO(aryan)
829801
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..modeling_outputs import AutoencoderKLOutput
3030
from ..modeling_utils import ModelMixin
3131
from ..upsampling import CogVideoXUpsample3D
32-
from .vae import DecoderOutput, DiagonalGaussianDistribution
32+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
3333

3434

3535
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -955,7 +955,7 @@ def forward(
955955
return hidden_states, new_conv_cache
956956

957957

958-
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
958+
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
959959
r"""
960960
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
961961
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -1124,27 +1124,6 @@ def enable_tiling(
11241124
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
11251125
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
11261126

1127-
def disable_tiling(self) -> None:
1128-
r"""
1129-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1130-
decoding in one step.
1131-
"""
1132-
self.use_tiling = False
1133-
1134-
def enable_slicing(self) -> None:
1135-
r"""
1136-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1137-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1138-
"""
1139-
self.use_slicing = True
1140-
1141-
def disable_slicing(self) -> None:
1142-
r"""
1143-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1144-
decoding in one step.
1145-
"""
1146-
self.use_slicing = False
1147-
11481127
def _encode(self, x: torch.Tensor) -> torch.Tensor:
11491128
batch_size, num_channels, num_frames, height, width = x.shape
11501129

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ...utils.accelerate_utils import apply_forward_hook
2525
from ..modeling_outputs import AutoencoderKLOutput
2626
from ..modeling_utils import ModelMixin
27-
from .vae import DecoderOutput, IdentityDistribution
27+
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
2828

2929

3030
logger = get_logger(__name__)
@@ -875,7 +875,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
875875
return hidden_states
876876

877877

878-
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
878+
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
879879
r"""
880880
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
881881
@@ -1031,27 +1031,6 @@ def enable_tiling(
10311031
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
10321032
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
10331033

1034-
def disable_tiling(self) -> None:
1035-
r"""
1036-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1037-
decoding in one step.
1038-
"""
1039-
self.use_tiling = False
1040-
1041-
def enable_slicing(self) -> None:
1042-
r"""
1043-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1044-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1045-
"""
1046-
self.use_slicing = True
1047-
1048-
def disable_slicing(self) -> None:
1049-
r"""
1050-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1051-
decoding in one step.
1052-
"""
1053-
self.use_slicing = False
1054-
10551034
def _encode(self, x: torch.Tensor) -> torch.Tensor:
10561035
x = self.encoder(x)
10571036
enc = self.quant_conv(x)

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..attention_processor import Attention
2727
from ..modeling_outputs import AutoencoderKLOutput
2828
from ..modeling_utils import ModelMixin
29-
from .vae import DecoderOutput, DiagonalGaussianDistribution
29+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
3030

3131

3232
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -624,7 +624,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
624624
return hidden_states
625625

626626

627-
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
627+
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
628628
r"""
629629
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
630630
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -763,27 +763,6 @@ def enable_tiling(
763763
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
764764
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
765765

766-
def disable_tiling(self) -> None:
767-
r"""
768-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
769-
decoding in one step.
770-
"""
771-
self.use_tiling = False
772-
773-
def enable_slicing(self) -> None:
774-
r"""
775-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
776-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
777-
"""
778-
self.use_slicing = True
779-
780-
def disable_slicing(self) -> None:
781-
r"""
782-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
783-
decoding in one step.
784-
"""
785-
self.use_slicing = False
786-
787766
def _encode(self, x: torch.Tensor) -> torch.Tensor:
788767
batch_size, num_channels, num_frames, height, width = x.shape
789768

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..modeling_outputs import AutoencoderKLOutput
2727
from ..modeling_utils import ModelMixin
2828
from ..normalization import RMSNorm
29-
from .vae import DecoderOutput, DiagonalGaussianDistribution
29+
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
3030

3131

3232
class LTXVideoCausalConv3d(nn.Module):
@@ -1034,7 +1034,7 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No
10341034
return hidden_states
10351035

10361036

1037-
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1037+
class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
10381038
r"""
10391039
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
10401040
[LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1219,27 +1219,6 @@ def enable_tiling(
12191219
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
12201220
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
12211221

1222-
def disable_tiling(self) -> None:
1223-
r"""
1224-
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1225-
decoding in one step.
1226-
"""
1227-
self.use_tiling = False
1228-
1229-
def enable_slicing(self) -> None:
1230-
r"""
1231-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1232-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1233-
"""
1234-
self.use_slicing = True
1235-
1236-
def disable_slicing(self) -> None:
1237-
r"""
1238-
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1239-
decoding in one step.
1240-
"""
1241-
self.use_slicing = False
1242-
12431222
def _encode(self, x: torch.Tensor) -> torch.Tensor:
12441223
batch_size, num_channels, num_frames, height, width = x.shape
12451224

0 commit comments

Comments
 (0)