diff --git a/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py b/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py index 404f0c5107..2f98d2c4b7 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -24,10 +24,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder -class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): +class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin): r""" Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -112,9 +112,6 @@ def __init__( self.quant_conv = mint.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = mint.nn.Conv2d(latent_channels, latent_channels, 1) - self.use_slicing = False - self.use_tiling = False - self.register_to_config(block_out_channels=up_block_out_channels) self.register_to_config(force_upcast=False) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_dc.py b/mindone/diffusers/models/autoencoders/autoencoder_dc.py index a14c008cdf..c94caa5a86 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_dc.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_dc.py @@ -30,7 +30,7 @@ from ..modeling_utils import ModelMixin from ..normalization import RMSNorm, get_normalization from ..transformers.sana_transformer import GLUMBConv -from .vae import DecoderOutput, EncoderOutput +from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput class ResBlock(nn.Cell): @@ -393,7 +393,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: return hidden_states -class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in [SANA](https://huggingface.co/papers/2410.10629). @@ -551,27 +551,6 @@ def enable_tiling( self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio - def disable_tiling(self) -> None: - r""" - Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute - decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. 
- """ - self.use_slicing = False - def _encode(self, x: ms.Tensor) -> ms.Tensor: batch_size, num_channels, height, width = x.shape diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl.py b/mindone/diffusers/models/autoencoders/autoencoder_kl.py index 8db2ef0ca3..ba1b7d918d 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl.py @@ -28,10 +28,10 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder -class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): +class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -135,35 +135,6 @@ def __init__( self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - def enable_tiling(self, use_tiling: bool = True): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py index e50a40608c..899a23935a 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -32,6 +32,7 @@ from ..modeling_utils import ModelMixin from ..resnet import ResnetBlock2D from ..upsampling import Upsample2D +from .vae import AutoencoderMixin class AllegroTemporalConvLayer(nn.Cell): @@ -685,7 +686,7 @@ def construct(self, sample: ms.Tensor) -> ms.Tensor: return sample -class AutoencoderKLAllegro(ModelMixin, ConfigMixin): +class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in [Allegro](https://github.com/rhymes-ai/Allegro). @@ -808,35 +809,6 @@ def __init__( sample_size - self.tile_overlap_w, ) - def enable_tiling(self) -> None: - r""" - Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = True - - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: ms.Tensor) -> ms.Tensor: # TODO(aryan) # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 9de498da03..698d1e8755 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -32,7 +32,7 @@ from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..upsampling import CogVideoXUpsample3D -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -891,7 +891,7 @@ def construct( return hidden_states, new_conv_cache -class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [CogVideoX](https://github.com/THUDM/CogVideo). @@ -1061,27 +1061,6 @@ def enable_tiling( self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. 
- """ - self.use_slicing = False - def _encode(self, x: ms.Tensor) -> ms.Tensor: batch_size, num_channels, num_frames, height, width = x.shape diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py index 076d8b2a93..12f3f5471e 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_cosmos.py @@ -25,7 +25,7 @@ from ..layers_compat import conv_transpose3d, unflatten from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, IdentityDistribution +from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution logger = get_logger(__name__) @@ -915,7 +915,7 @@ def construct(self, hidden_states: ms.tensor) -> ms.tensor: return hidden_states -class AutoencoderKLCosmos(ModelMixin, ConfigMixin): +class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin): r""" Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575). @@ -1072,28 +1072,7 @@ def enable_tiling( self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - - def _encode(self, x: ms.tensor) -> ms.tensor: + def _encode(self, x: ms.Tensor) -> ms.Tensor: x = self.encoder(x) enc = self.quant_conv(x) return enc diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index faa5cf587c..d77b37083f 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -30,7 +30,7 @@ from ..layers_compat import unflatten from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -595,7 +595,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: return hidden_states -class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): +class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). 
@@ -736,27 +736,6 @@ def enable_tiling( self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: ms.Tensor) -> ms.Tensor: batch_size, num_channels, num_frames, height, width = x.shape @@ -777,7 +756,7 @@ def encode( Encode a batch of images into latents. Args: - x (`torch.Tensor`): Input batch of images. + x (`ms.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. @@ -823,7 +802,7 @@ def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput Decode a batch of images. Args: - z (`torch.Tensor`): Input batch of latent vectors. + z (`ms.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -871,10 +850,10 @@ def tiled_encode(self, x: ms.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. Args: - x (`torch.Tensor`): Input batch of videos. + x (`ms.Tensor`): Input batch of videos. Returns: - `torch.Tensor`: + `ms.Tensor`: The latent representation of the encoded videos. """ batch_size, num_channels, num_frames, height, width = x.shape @@ -922,7 +901,7 @@ def tiled_decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[Decoder Decode a batch of images using a tiled decoder. Args: - z (`torch.Tensor`): Input batch of latent vectors. + z (`ms.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -1051,7 +1030,7 @@ def construct( ) -> Union[DecoderOutput, ms.Tensor]: r""" Args: - sample (`torch.Tensor`): Input sample. + sample (`ms.Tensor`): Input sample. sample_posterior (`bool`, *optional*, defaults to `False`): Whether to sample from the posterior. 
return_dict (`bool`, *optional*, defaults to `True`): diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 49a6769883..c7a7abe8de 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -32,7 +32,7 @@ from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..normalization import RMSNorm -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution class LTXVideoCausalConv3d(nn.Cell): @@ -1044,7 +1044,7 @@ def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) return hidden_states -class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [LTX](https://huggingface.co/Lightricks/LTX-Video). @@ -1229,27 +1229,6 @@ def enable_tiling( self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: ms.Tensor) -> ms.Tensor: batch_size, num_channels, num_frames, height, width = x.shape diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py index e3bbfbcb70..880518e3f5 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_magvit.py @@ -30,7 +30,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -671,7 +671,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: return hidden_states -class AutoencoderKLMagvit(ModelMixin, ConfigMixin): +class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991). 
@@ -815,27 +815,6 @@ def enable_tiling( self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode( self, x: ms.Tensor, return_dict: bool = False ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 5b56e634de..3caee3b22f 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -32,7 +32,7 @@ from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -635,7 +635,7 @@ def construct(self, hidden_states: ms.Tensor, conv_cache: Optional[Dict[str, ms. return hidden_states, new_conv_cache -class AutoencoderKLMochi(ModelMixin, ConfigMixin): +class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [Mochi 1 preview](https://github.com/genmoai/models). @@ -798,27 +798,6 @@ def enable_tiling( self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _enable_framewise_encoding(self): r""" Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the @@ -921,7 +900,7 @@ def decode(self, z: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput Decode a batch of images. Args: - z (`torch.Tensor`): Input batch of latent vectors. + z (`ms.Tensor`): Input batch of latent vectors. 
return_dict (`bool`, *optional*, defaults to `False`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. @@ -961,10 +940,10 @@ def tiled_encode(self, x: ms.Tensor) -> ms.Tensor: r"""Encode a batch of images using a tiled encoder. Args: - x (`torch.Tensor`): Input batch of videos. + x (`ms.Tensor`): Input batch of videos. Returns: - `torch.Tensor`: + `ms.Tensor`: The latent representation of the encoded videos. """ batch_size, num_channels, num_frames, height, width = x.shape diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index de36af7aba..5edc03d3bb 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -36,7 +36,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -670,7 +670,7 @@ def construct(self, x, feat_cache=None, feat_idx=[0]): return x -class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. @@ -778,27 +778,6 @@ def enable_tiling( self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. 
- """ - self.use_slicing = False - def clear_cache(self): def _count_conv3d(model): count = 0 diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index c44b5395bb..2136fea3f4 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -26,7 +26,7 @@ from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder -from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder class TemporalDecoder(nn.Cell): @@ -120,7 +120,7 @@ def construct( return sample -class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): +class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py index 6d4de4a3c1..fd8eadfe7f 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -29,7 +29,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -474,14 +474,14 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", def construct(self, x, feat_cache=None, feat_idx=[0]): # First residual block - x = self.resnets[0](x, feat_cache, feat_idx) + x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx) # Process through attention and residual blocks for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: x = attn(x) - x = resnet(x, feat_cache, feat_idx) + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) return x @@ -515,9 +515,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample def construct(self, x, feat_cache=None, feat_idx=[0]): x_copy = x.clone() for resnet in self.resnets: - x = resnet(x, feat_cache, feat_idx) + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) if self.downsampler is not None: - x = self.downsampler(x, feat_cache, feat_idx) + x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) return x + self.avg_shortcut(x_copy) @@ -619,12 +619,12 @@ def construct(self, x, feat_cache=None, feat_idx=[0]): # downsamples for layer in self.down_blocks: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) # middle - x = self.mid_block(x, feat_cache, feat_idx) + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) # head x = self.norm_out(x) @@ -715,13 +715,13 @@ def construct(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): for resnet in self.resnets: if feat_cache is not None: - x = resnet(x, feat_cache, feat_idx) + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = resnet(x) if self.upsampler is not None: if feat_cache is not 
None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)
 
@@ -788,13 +788,13 @@ def construct(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -906,11 +906,11 @@ def construct(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         x = self.conv_in(x)
 
         # middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         # upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
 
         # head
         x = self.norm_out(x)
@@ -972,7 +972,7 @@ def unpatchify(x, patch_size):
     return x
 
 
-class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
     Introduced in [Wan 2.1].
@@ -982,6 +982,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
 
     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDeviceHook moves inputs/outputs between devices;
+    # these are shared mutable state that is modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]
 
     @register_to_config
     def __init__(
@@ -989,7 +992,7 @@ def __init__(
         base_dim: int = 96,
         decoder_base_dim: Optional[int] = None,
         z_dim: int = 16,
-        dim_mult: Tuple[int] = [1, 2, 4, 4],
+        dim_mult: List[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
@@ -1074,7 +1077,7 @@ def __init__(
         self.diag_gauss_dist = DiagonalGaussianDistribution()
 
-        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+        self.spatial_compression_ratio = scale_factor_spatial
 
         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
@@ -1133,27 +1136,6 @@ def enable_tiling(
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
 
-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
- """ - self.use_slicing = False - def clear_cache(self): # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call self._conv_num = self._cached_conv_counts["decoder"] @@ -1167,12 +1149,13 @@ def clear_cache(self): def _encode(self, x: ms.Tensor): _, _, num_frame, height, width = x.shape - if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): - return self.tiled_encode(x) - self.clear_cache() if self.config.patch_size is not None: x = patchify(x, patch_size=self.config.patch_size) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + iter_ = 1 + (num_frame - 1) // 4 for i in range(iter_): self._enc_conv_idx = [0] @@ -1377,9 +1360,18 @@ def tiled_decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderO tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio - - blend_height = self.tile_sample_min_height - self.tile_sample_stride_height - blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = self.tile_sample_min_height - tile_sample_stride_height + blend_width = self.tile_sample_min_width - tile_sample_stride_width # Split z into overlapping tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. 
@@ -1393,7 +1385,9 @@ def tiled_decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderO
                 self._conv_idx = [0]
                 tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                 tile = self.post_quant_conv(tile)
-                decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                decoded = self.decoder(
+                    tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+                )
                 time.append(decoded)
             row.append(mint.cat(time, dim=2))
         rows.append(row)
@@ -1409,11 +1403,15 @@ def tiled_decode(self, z: ms.Tensor, return_dict: bool = True) -> Union[DecoderO
                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
                     tile = self.blend_h(row[j - 1], tile, blend_width)
-                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
             result_rows.append(mint.cat(result_row, dim=-1))
 
         dec = mint.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+        dec = mint.clamp(dec, min=-1.0, max=1.0)
 
         if not return_dict:
             return (dec,)
 
         return DecoderOutput(sample=dec)
diff --git a/mindone/diffusers/models/autoencoders/autoencoder_oobleck.py b/mindone/diffusers/models/autoencoders/autoencoder_oobleck.py
index a6d98b4707..143ff56945 100644
--- a/mindone/diffusers/models/autoencoders/autoencoder_oobleck.py
+++ b/mindone/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -28,6 +28,7 @@
 from ...utils import BaseOutput
 from ...utils.mindspore_utils import randn_tensor
 from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin
 
 
 class Snake1d(nn.Cell):
@@ -326,7 +327,7 @@ def construct(self, hidden_state):
         return hidden_state
 
 
-class AutoencoderOobleck(ModelMixin, ConfigMixin):
+class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
     r"""
     An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
     introduced in Stable Audio.
@@ -391,20 +392,6 @@ def __init__(
 
         self.use_slicing = False
 
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
- """ - self.use_slicing = False - def encode( self, x: ms.Tensor, return_dict: bool = True ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]: diff --git a/mindone/diffusers/models/autoencoders/autoencoder_tiny.py b/mindone/diffusers/models/autoencoders/autoencoder_tiny.py index 08e418b2e9..1fb3ae5f30 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_tiny.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_tiny.py @@ -27,7 +27,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DecoderTiny, EncoderTiny +from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny @dataclass @@ -43,7 +43,7 @@ class AutoencoderTinyOutput(BaseOutput): latents: ms.Tensor -class AutoencoderTiny(ModelMixin, ConfigMixin): +class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A tiny distilled VAE model for encoding images into latents and decoding latent representations into images. @@ -167,35 +167,6 @@ def unscale_latents(self, x: ms.Tensor) -> ms.Tensor: """[0, 1] -> raw latents""" return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - - def enable_tiling(self, use_tiling: bool = True) -> None: - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - def _tiled_encode(self, x: ms.Tensor) -> ms.Tensor: r"""Encode a batch of images using a tiled encoder. diff --git a/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py b/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py index 9fb6d8a9dd..5f50a22db4 100644 --- a/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -29,7 +29,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_utils import ModelMixin from ..unets.unet_2d import UNet2DModel -from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder @dataclass @@ -46,7 +46,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput): latent: ms.Tensor -class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): +class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin): r""" The consistency decoder used with DALL-E 3. 
@@ -159,39 +159,6 @@ def __init__( self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling - def enable_tiling(self, use_tiling: bool = True): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore diff --git a/mindone/diffusers/models/autoencoders/vae.py b/mindone/diffusers/models/autoencoders/vae.py index 16f3dcdcb8..e7b2abdfb3 100644 --- a/mindone/diffusers/models/autoencoders/vae.py +++ b/mindone/diffusers/models/autoencoders/vae.py @@ -867,3 +867,38 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: # scale image from [0, 1] to [-1, 1] to match diffusers convention return x.mul(2).sub(1) + + +class AutoencoderMixin: + def enable_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + if not hasattr(self, "use_tiling"): + raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_tiling = True + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + if not hasattr(self, "use_slicing"): + raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_slicing = False diff --git a/mindone/diffusers/models/autoencoders/vq_model.py b/mindone/diffusers/models/autoencoders/vq_model.py index f9a124bc6f..aadf3473df 100644 --- a/mindone/diffusers/models/autoencoders/vq_model.py +++ b/mindone/diffusers/models/autoencoders/vq_model.py @@ -24,6 +24,7 @@ from ...utils import BaseOutput from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin @dataclass @@ -39,7 +40,7 @@ class VQEncoderOutput(BaseOutput): latents: ms.Tensor -class VQModel(ModelMixin, ConfigMixin): +class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VQ-VAE model for decoding latent representations.