Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder


class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The AsymmetricAutoencoderKL class does not seem to implement tiling or slicing logic, and with the removal of self.use_slicing and self.use_tiling attributes, calling enable_tiling() or enable_slicing() will raise a NotImplementedError. This is misleading. It would be better to not inherit from AutoencoderMixin.

The corresponding import on line 27 should also be updated to remove AutoencoderMixin.

Suggested change
class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):

r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.
Expand Down Expand Up @@ -112,9 +112,6 @@ def __init__(
self.quant_conv = mint.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = mint.nn.Conv2d(latent_channels, latent_channels, 1)

self.use_slicing = False
self.use_tiling = False

self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)

Expand Down
25 changes: 2 additions & 23 deletions mindone/diffusers/models/autoencoders/autoencoder_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import DecoderOutput, EncoderOutput
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput


class ResBlock(nn.Cell):
Expand Down Expand Up @@ -393,7 +393,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
return hidden_states


class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629).
Expand Down Expand Up @@ -551,27 +551,6 @@ def enable_tiling(
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

def disable_tiling(self) -> None:
r"""
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: ms.Tensor) -> ms.Tensor:
batch_size, num_channels, height, width = x.shape

Expand Down
33 changes: 2 additions & 31 deletions mindone/diffusers/models/autoencoders/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

Expand Down Expand Up @@ -135,35 +135,6 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling

def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)

def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
Expand Down
32 changes: 2 additions & 30 deletions mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
from .vae import AutoencoderMixin


class AllegroTemporalConvLayer(nn.Cell):
Expand Down Expand Up @@ -685,7 +686,7 @@ def construct(self, sample: ms.Tensor) -> ms.Tensor:
return sample


class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
Expand Down Expand Up @@ -808,35 +809,6 @@ def __init__(
sample_size - self.tile_overlap_w,
)

def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: ms.Tensor) -> ms.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -891,7 +891,7 @@ def construct(
return hidden_states, new_conv_cache


class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
Expand Down Expand Up @@ -1061,27 +1061,6 @@ def enable_tiling(
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: ms.Tensor) -> ms.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

Expand Down
27 changes: 3 additions & 24 deletions mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..layers_compat import conv_transpose3d, unflatten
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, IdentityDistribution
from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution

logger = get_logger(__name__)

Expand Down Expand Up @@ -915,7 +915,7 @@ def construct(self, hidden_states: ms.tensor) -> ms.tensor:
return hidden_states


class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).

Expand Down Expand Up @@ -1072,28 +1072,7 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: ms.tensor) -> ms.tensor:
def _encode(self, x: ms.Tensor) -> ms.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
return enc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..layers_compat import unflatten
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -595,7 +595,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
return hidden_states


class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
Expand Down Expand Up @@ -736,27 +736,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False

def _encode(self, x: ms.Tensor) -> ms.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

Expand All @@ -777,7 +756,7 @@ def encode(
Encode a batch of images into latents.

Args:
x (`torch.Tensor`): Input batch of images.
x (`ms.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

Expand Down Expand Up @@ -823,7 +802,7 @@ def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput
Decode a batch of images.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
z (`ms.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Expand Down Expand Up @@ -871,10 +850,10 @@ def tiled_encode(self, x: ms.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.

Args:
x (`torch.Tensor`): Input batch of videos.
x (`ms.Tensor`): Input batch of videos.

Returns:
`torch.Tensor`:
`ms.Tensor`:
The latent representation of the encoded videos.
"""
batch_size, num_channels, num_frames, height, width = x.shape
Expand Down Expand Up @@ -922,7 +901,7 @@ def tiled_decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[Decoder
Decode a batch of images using a tiled decoder.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
z (`ms.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Expand Down Expand Up @@ -1051,7 +1030,7 @@ def construct(
) -> Union[DecoderOutput, ms.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample (`ms.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Expand Down
Loading