diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index be4b86321..886d5e070 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -48,6 +48,10 @@ def check_qaic_sdk(): QEFFCommonLoader, ) from QEfficient.compile.compile_helper import compile + + # Imports for the diffusers + from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline + from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEFFWanPipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -67,6 +71,8 @@ def check_qaic_sdk(): "QEFFAutoModelForImageTextToText", "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", + "QEFFFluxPipeline", + "QEFFWanPipeline" ] else: diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index d9d6823ae..e2243c878 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -22,7 +22,7 @@ from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json +from QEfficient.utils import constants, create_json, generate_mdp_partition_config, load_json # dump_qconfig #TODO: debug and enable from QEfficient.utils.cache import QEFF_HOME, to_hashable logger = logging.getLogger(__name__) @@ -179,7 +179,8 @@ def _export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=constants.ONNX_EXPORT_OPSET, + opset_version=17, + # verbose=True, **export_kwargs, ) logger.info("Pytorch export successful") @@ -213,7 +214,7 @@ def _export( self.onnx_path = onnx_path return onnx_path - @dump_qconfig + # @dump_qconfig def _compile( self, onnx_path: Optional[str] = None, @@ -352,6 +353,7 @@ def _compile( command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") + print(command) try: subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md new file mode 100644 index 000000000..088108461 --- /dev/null +++ b/QEfficient/diffusers/README.md @@ -0,0 +1,110 @@ + +
+ + +# **Diffusion Models on Qualcomm Cloud AI 100** + + +
+ +### 🎨 **Experience the Future of AI Image Generation** + +*Optimized for Qualcomm Cloud AI 100* + +Sample Output + +**Generated with**: `stabilityai/stable-diffusion-3.5-large` • `"A girl laughing"` • 28 steps • 2.0 guidance scale • ⚡ + + + +
+ + + +[![Diffusers](https://img.shields.io/badge/Diffusers-0.31.0-orange.svg)](https://github.com/huggingface/diffusers) +
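The snippet below is a minimal, illustrative sketch of how the `QEFFFluxPipeline` added in this PR is meant to be driven (load → compile → generate). The model id, device counts, and output handling are assumptions, not part of this change, and may need adjusting for your setup:

```python
from QEfficient import QEFFFluxPipeline

# Load the HuggingFace Flux pipeline and wrap its components with the
# QEfficient-optimized modules (model id is illustrative).
pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

# Export each sub-module to ONNX and compile QPCs for Cloud AI 100.
pipeline.compile(num_devices_transformer=4, num_cores=16, mxfp6_matmul=True)

# Text-to-image generation on the compiled artifacts; output handling assumes
# the standard diffusers FluxPipelineOutput with an `images` list.
result = pipeline(prompt="A girl laughing", num_inference_steps=28, guidance_scale=3.5)
result.images[0].save("flux_output.png")
```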
+ +--- + +## ✨ Overview + +QEfficient Diffusers brings the power of state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipeline provides seamless inference on Qualcomm Cloud AI 100 hardware. + +## 🛠️ Installation + +### Prerequisites + +Ensure you have Python 3.8+ and the required dependencies: + +```bash +# Create Python virtual environment (Recommended Python 3.10) +sudo apt install python3.10-venv +python3.10 -m venv qeff_env +source qeff_env/bin/activate +pip install -U pip +``` + +### Install QEfficient + +```bash +# Install from GitHub (includes diffusers support) +pip install git+https://github.com/quic/efficient-transformers + +# Or build from source +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install build wheel +python -m build --wheel --outdir dist +pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl +``` + +### Install Diffusers Dependencies + +```bash +# Install diffusers optional dependencies +pip install "QEfficient[diffusers]" +``` + +--- + +## 🎯 Supported Models + +### Stable Diffusion 3.x Series +- ✅ [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) +- ✅ [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) +--- + + +## 📚 Examples + +Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory: + +--- + +## 🤝 Contributing + +We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details. + +### Development Setup + +```bash +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install -e ".[diffusers,test]" +``` + +--- + +## 🙏 Acknowledgments + +- **HuggingFace Diffusers**: For the excellent foundation library +- **Stability AI**: For the amazing Stable Diffusion models +--- + +## 📞 Support + +- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/) +- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues) + +--- + diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/diffusers/models/attention.py b/QEfficient/diffusers/models/attention.py new file mode 100644 index 000000000..3c9cc268d --- /dev/null +++ b/QEfficient/diffusers/models/attention.py @@ -0,0 +1,75 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch +from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward + + +class QEffJointTransformerBlock(JointTransformerBlock): + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + # ff_output = self.ff(norm_hidden_states) + ff_output = self.ff(norm_hidden_states, block_size=4096) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + # context_ff_output = self.ff_context(norm_encoder_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states diff --git a/QEfficient/diffusers/models/attention_processor.py b/QEfficient/diffusers/models/attention_processor.py new file mode 100644 index 000000000..01954e55e --- /dev/null +++ b/QEfficient/diffusers/models/attention_processor.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Optional + +import torch +from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0 + + +class QEffAttention(Attention): + def __qeff_init__(self): + processor = QEffJointAttnProcessor2_0() + self.processor = processor + processor.query_block_size = 64 + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key, + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + +class QEffJointAttnProcessor2_0(JointAttnProcessor2_0): + def __call__( + self, + attn: QEffAttention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. 
+ if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + query = query.reshape(-1, query.shape[-2], query.shape[-1]) + key = key.reshape(-1, key.shape[-2], key.shape[-1]) + value = value.reshape(-1, value.shape[-2], value.shape[-1]) + + # pre-transpose the key + key = key.transpose(-1, -2) + if query.size(-2) != value.size(-2): # cross-attention, use regular attention + # QKV done in single block + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + else: # self-attention, use blocked attention + # QKV done with block-attention (a la FlashAttentionV2) + query_block_size = self.query_block_size + query_seq_len = query.size(-2) + num_blocks = (query_seq_len + query_block_size - 1) // query_block_size + for qidx in range(num_blocks): + query_block = query[:, qidx * query_block_size : (qidx + 1) * query_block_size, :] + attention_probs = attn.get_attention_scores(query_block, key, attention_mask) + hidden_states_block = torch.bmm(attention_probs, value) + if qidx == 0: + hidden_states = hidden_states_block + else: + hidden_states = torch.cat((hidden_states, hidden_states_block), -2) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states diff --git a/QEfficient/diffusers/models/autoencoders/__init__.py b/QEfficient/diffusers/models/autoencoders/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/models/autoencoders/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py new file mode 100644 index 000000000..c652f07d2 --- /dev/null +++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch +from diffusers import AutoencoderKL + + +class QEffAutoencoderKL(AutoencoderKL): + def encode(self, x: torch.Tensor, return_dict: bool = True): + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + return h diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py new file mode 100644 index 000000000..ecd92e7d4 --- /dev/null +++ b/QEfficient/diffusers/models/normalization.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +import numbers +from typing import Dict, Optional, Tuple +import torch +from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle,AdaLayerNormContinuous + + +class QEffAdaLayerNormZero(AdaLayerNormZero): + def forward( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + hidden_dtype: Optional[torch.dtype] = None, + shift_msa: Optional[torch.Tensor] = None, + scale_msa: Optional[torch.Tensor] = None, + # emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + # emb = self.linear(self.silu(emb)) + # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + +class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle): + def forward( + self, + x: torch.Tensor, + scale_msa: Optional[torch.Tensor] = None, + shift_msa: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + +class QEffAdaLayerNormContinuous(AdaLayerNormContinuous): + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + # emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + emb = conditioning_embedding + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x \ No newline at end of file diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py new file mode 100644 index 000000000..c7a7bcc61 --- /dev/null +++ b/QEfficient/diffusers/models/pytorch_transforms.py @@ -0,0 +1,64 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Tuple +import torch +from torch import nn +from diffusers.models.attention import JointTransformerBlock +from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0 +from diffusers.models.normalization import RMSNorm, AdaLayerNormZero, AdaLayerNormZeroSingle, AdaLayerNormContinuous +from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.customop.rms_norm import CustomRMSNormAIC +from QEfficient.diffusers.models.attention import QEffJointTransformerBlock +from QEfficient.diffusers.models.attention_processor import ( + QEffAttention, + QEffJointAttnProcessor2_0, +) +from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxSingleTransformerBlock, QEffFluxTransformerBlock, QEffFluxTransformer2DModel +from QEfficient.diffusers.models.transformers.transformer_wan import QEFFWanTransformer3DModel +from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + RMSNorm: CustomRMSNormAIC, + nn.RMSNorm: CustomRMSNormAIC # for torch.nn.RMSNorm + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + + +class AttentionTransform(ModuleMappingTransform): + _module_mapping = { + Attention: QEffAttention, + JointAttnProcessor2_0: QEffJointAttnProcessor2_0, + JointTransformerBlock: QEffJointTransformerBlock, + FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock, + FluxTransformerBlock: QEffFluxTransformerBlock, + FluxTransformer2DModel: QEffFluxTransformer2DModel, + WanTransformer3DModel : QEFFWanTransformer3DModel, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + +class NormalizationTransform(ModuleMappingTransform): + _module_mapping = { + AdaLayerNormZero: QEffAdaLayerNormZero, + AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle, + AdaLayerNormContinuous: QEffAdaLayerNormContinuous, + } + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py new file mode 100644 index 000000000..d1abade8f --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_flux.py @@ -0,0 +1,378 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +import os +from typing import Any, Callable, Dict, List, Tuple, Optional, Union +from venv import logger + +import torch +import torch.nn as nn +import numpy as np + +from diffusers.models.transformers.transformer_flux import FluxAttention,FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxPosEmbed, FluxAttnProcessor +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings +from diffusers.models.attention import FeedForward + +from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous + +class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio) + self.mlp_hidden_dim = int(dim * mlp_ratio) + self.norm = QEffAdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=FluxAttnProcessor(), + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + temb = tuple(torch.split(temb, 1)) + gate = temb[2] + residual = hidden_states + norm_hidden_states = self.norm(hidden_states, scale_msa=temb[1], shift_msa=temb[0]) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + # if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + +class QEffFluxTransformerBlock(FluxTransformerBlock): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__(dim, num_attention_heads, attention_head_dim) + + self.norm1 = QEffAdaLayerNormZero(dim) + self.norm1_context = QEffAdaLayerNormZero(dim) + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=FluxAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, 
activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + temb1 = tuple(torch.split(temb[:6], 1)) + temb2 = tuple(torch.split(temb[6:], 1)) + norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1]) + gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:] + + norm_encoder_hidden_states = self.norm1_context( + encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1] + ) + + c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:] + + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + # if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + +class QEffFluxTransformer2DModel(FluxTransformer2DModel): + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + + resolved_out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + super().__init__( + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + pooled_projection_dim=pooled_projection_dim, + guidance_embeds=guidance_embeds, + axes_dims_rope=axes_dims_rope, + ) + + self.out_channels = resolved_out_channels + self.inner_dim = inner_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QEffFluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + QEffFluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + adaln_emb: torch.Tensor = None, + adaln_single_emb: torch.Tensor = None, + adaln_out: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + 
""" + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. 
+ if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_single_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, adaln_out) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py new file mode 100644 index 000000000..bd9bf0e34 --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -0,0 +1,219 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +import os +from typing import Any, Callable, Dict, List, Tuple, Optional, Union +from venv import logger + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_wan import WanAttention, _get_qkv_projections, _get_added_kv_projections, dispatch_attention_fn, WanTransformer3DModel + + + +class WanAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2].type_as(hidden_states) # updated from patch # TODO : get confirm + sin = freqs_sin[..., 1::2].type_as(hidden_states) # updated from patch + real = x1 * cos - x2 * sin # updated from patch + img = x1 * sin + x2 * cos # updated from patch + x_rot = torch.stack([real,img],dim=-1) # updated from patch + return x_rot.flatten(-2).type_as(hidden_states) # updated from patch + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class QEFFWanTransformer3DModel(WanTransformer3DModel): + def forward( + self, + hidden_states: torch.Tensor, + # timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + rotary_emb: torch.Tensor, + temb: torch.Tensor, + timestep_proj: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + # if attention_kwargs is not None: + # attention_kwargs = attention_kwargs.copy() + # lora_scale = attention_kwargs.pop("scale", 1.0) + # else: + # lora_scale = 1.0 + + # if USE_PEFT_BACKEND: + # # weight the lora layers by setting `lora_scale` for each PEFT layer + # scale_lora_layers(self, lora_scale) + # else: + # if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + # 
logger.warning( + # "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + # ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # rotary_emb = self.rope(hidden_states) + rotary_emb = torch.split(rotary_emb, 1, dim=0) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + # if timestep.ndim == 2: + # ts_seq_len = timestep.shape[1] + # timestep = timestep.flatten() # batch_size * seq_len + # else: + # ts_seq_len = None + + # temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + # timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + # ) + # if ts_seq_len is not None: + # # batch_size, seq_len, 6, inner_dim + # timestep_proj = timestep_proj.unflatten(2, (6, -1)) + # else: + # # batch_size, 6, inner_dim + # timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # 5. Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + # hidden_states = hidden_states.reshape( + # batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + # ) + # hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + # output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + ## Compiler Fix ## + output = hidden_states + + # if USE_PEFT_BACKEND: + # # remove `lora_scale` from each PEFT layer + # unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py new file mode 100644 index 000000000..2bc49f7ea --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 000000000..53e0cf0ec --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,823 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +from venv import logger + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.image_processor import VaeImageProcessor, PipelineImageInput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps # TODO +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput + +from QEfficient.diffusers.pipelines.pipeline_utils import QEffTextEncoder, QEffClipTextEncoder, QEffVAE, QEffFluxTransformerModel +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants + +class QEFFFluxPipeline(FluxPipeline): + _hf_auto_class = FluxPipeline + """ + A QEfficient-optimized Flux pipeline, inheriting from `diffusers.FluxPipeline`. + + This class integrates QEfficient components (e.g., optimized models for Clip, t5 text encoders, + flux transformer, and VAE) to enhance performance, particularly for deployment on Qualcomm AI hardware. + It provides methods for text-to-image generation leveraging these optimized components. + """ + + def __init__(self, model, *args, **kwargs): + self.text_encoder = QEffClipTextEncoder(model.text_encoder) + self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2) + self.transformer = QEffFluxTransformerModel(model.transformer) + self.vae_decode = QEffVAE(model, "decoder") + + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.text_encoder_2.tokenizer = model.tokenizer_2 + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + + self.register_modules( + vae=self.vae_decode, + text_encoder= self.text_encoder, + text_encoder_2= self.text_encoder_2, + tokenizer= self.tokenizer , + tokenizer_2= self.text_encoder_2.tokenizer, + transformer=self.transformer, + scheduler=self.scheduler, + ) + + self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode( + latent_sample, return_dict + ) + + self.vae_scale_factor = ( + 2 ** (len(model.vae.config.block_out_channels) - 1) if getattr(model, "vae", None) else 8 + ) + # Flux latents are turned into 2x2 patches and packed. 
This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.t_max_length = ( + model.tokenizer.model_max_length if hasattr(model, "tokenizer") and model.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Instantiate a QEffFluxTransformer2DModel from pretrained Diffusers models. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + The path to the pretrained model or its name. + **kwargs (additional keyword arguments): + Additional arguments that can be passed to the underlying `StableDiffusion3Pipeline.from_pretrained` + method. + """ + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + **kwargs, + ) + model.to("cpu") + return cls(model, pretrained_model_name_or_path) + + def export(self, export_dir: Optional[str] = None) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + + ``Optional`` Args: + :export_dir (str, optional): The directory path to store ONNX-graph. + + Returns: + :str: Path of the generated ``ONNX`` graph. + """ + + # text_encoder - CLIP + example_inputs_text_encoder, dynamic_axes_text_encoder, output_names_text_encoder = ( + self.text_encoder.get_onnx_config(seq_len = self.tokenizer.model_max_length) + ) + self.text_encoder.export( + inputs=example_inputs_text_encoder, + output_names=output_names_text_encoder, + dynamic_axes=dynamic_axes_text_encoder, + export_dir=export_dir, + ) + + # text_encoder_2 - T5 + example_inputs_text_encoder_2, dynamic_axes_text_encoder_2, output_names_text_encoder_2 = ( + self.text_encoder_2.get_onnx_config(seq_len = self.text_encoder_2.tokenizer.model_max_length) + ) + self.text_encoder_2.export( + inputs=example_inputs_text_encoder_2, + output_names=output_names_text_encoder_2, + dynamic_axes=dynamic_axes_text_encoder_2, + export_dir=export_dir, + ) + + # transformers + example_inputs_transformer, dynamic_axes_transformer, output_names_transformer = ( + self.transformer.get_onnx_config() + ) + self.transformer.export( + inputs=example_inputs_transformer, + output_names=output_names_transformer, + dynamic_axes=dynamic_axes_transformer, + export_dir=export_dir, + ) + + # vae + example_inputs_vae, dynamic_axes_vae, output_names_vae = self.vae_decode.get_onnx_config() + self.vae_decoder_onnx_path = self.vae_decode.export( + example_inputs_vae, + output_names_vae, + dynamic_axes_vae, + export_dir=export_dir, + ) + + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: Union[int, List[int]] = 256, + batch_size: int = 1, + num_devices_text_encoder: int = 1, + num_devices_transformer: int = 4, + num_devices_vae_decoder: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware. + + This method takes the ONNX paths of the text encoders, transformer, and VAE decoder, + and compiles them into an optimized format for inference. + + Args: + onnx_path (`str`, *optional*): + The base directory where ONNX files were exported. 
If None, it assumes the ONNX + paths are already set as attributes (e.g., `self.text_encoder_onnx_path`). + This parameter is currently not fully utilized as individual ONNX paths are derived + from the `export` method. + compile_dir (`str`, *optional*): + The directory path to store the compiled artifacts. If None, a default location + might be used by the underlying compilation process. + seq_len (`Union[int, List[int]]`, *optional*, defaults to 32): + The sequence length(s) to use for compiling the text encoders. Can be a single + integer or a list of integers for multiple sequence lengths. + batch_size (`int`, *optional*, defaults to 1): + The batch size to use for compilation. + num_devices_text_encoder (`int`, *optional*, defaults to 1): + The number of AI devices to deploy the text encoder models on. + num_devices_transformer (`int`, *optional*, defaults to 4): + The number of AI devices to deploy the transformer model on. + num_devices_vae_decoder (`int`, *optional*, defaults to 1): + The number of AI devices to deploy the VAE decoder model on. + num_cores (`int`, *optional*, defaults to 16): + The number of cores to use for compilation. This argument is currently marked + as `FIXME: Make this mandatory arg`. + mxfp6_matmul (`bool`, *optional*, defaults to `False`): + If `True`, enables mixed-precision floating-point 6-bit matrix multiplication + optimization during compilation. + **compiler_options: + Additional keyword arguments to pass to the underlying compiler. + + Returns: + `str`: A message indicating the compilation status or path to compiled artifacts. + (Note: The current implementation might need to return specific paths for each compiled model). + """ + if any( + path is None + for path in [ + self.text_encoder.onnx_path, + self.text_encoder_2.onnx_path, + self.transformer.onnx_path, + self.vae_decode.onnx_path, + ] + ): + self.export() + # text_encoder - CLIP + specializations_text_encoder = self.text_encoder.get_specializations( + batch_size, self.tokenizer.model_max_length + ) + + # self.text_encoder_compile_path = "" + self.text_encoder_compile_path = self.text_encoder._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_text_encoder, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_text_encoder, + aic_num_cores=num_cores, + **compiler_options, + ) + + # text encoder 2 - T5 + specializations_text_encoder_2 = self.text_encoder_2.get_specializations( + batch_size, seq_len + ) + + self.text_encoder_2_compile_path = self.text_encoder_2._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_text_encoder_2, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_text_encoder, + aic_num_cores=num_cores, + **compiler_options, + ) + + # transformer + specializations_transformer = self.transformer.get_specializations(batch_size, seq_len) + compiler_options = {"mos": 1, "mdts-mos":1} + self.trasformer_compile_path = self.transformer._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_transformer, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_transformer, + aic_num_cores=num_cores, + **compiler_options, + ) + + # vae + specializations_vae = self.vae_decode.get_specializations(batch_size) + self.vae_decoder_compile_path = self.vae_decode._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_vae, + convert_to_fp16=True, 
+ mdp_ts_num_devices=num_devices_vae_decoder, + ) + + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device_ids: Optional[List[int]] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + Get T5 prompt embeddings for the given prompt(s). + + Args: + prompt (Union[str, List[str]], optional): The input prompt(s) to encode. + num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt. + max_sequence_length (int, defaults to 256): Maximum sequence length for tokenization. + device ids (Optional[torch.device], optional): The device to place tensors on QAIC device ids. + dtype (Optional[torch.dtype], optional): The data type for tensors. + + Returns: + torch.Tensor: The T5 prompt embeddings with shape (batch_size * num_images_per_prompt, seq_len, hidden_size). + """ + + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + embed_dim = 4096 + + text_inputs = self.text_encoder_2.tokenizer( + prompt, + padding="max_length", + max_length= max_sequence_length, + truncation= True, + return_length= False, + return_overflowing_tokens= False, + return_tensors= "pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.text_encoder_2.tokenizer.batch_decode(untruncated_ids[:,self.text_encoder_2.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" { self.text_encoder_2.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if self.text_encoder_2.qpc_session is None: + self.text_encoder_2.qpc_session = QAICInferenceSession(str(self.text_encoder_2_compile_path), device_ids=device_ids) + + text_encoder_2_output = { + "last_hidden_state": np.random.rand(batch_size, max_sequence_length, embed_dim).astype(np.int32), + } + self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output) + + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"]) + + + # # # AIC Testing + # prompt_embeds_pytorch = self.text_encoder_2.model(text_input_ids, output_hidden_states=False) + # mad = torch.abs(prompt_embeds_pytorch["last_hidden_state"] - prompt_embeds).mean() + # print(">>>>>>>>>>>> MAD for text-encoder-2 - T5 => Pytorch vs AI 100:", mad) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device_ids: Optional[List[int]] = None, + ): + """ + Get CLIP prompt embeddings for a given text encoder and tokenizer. + + Args: + prompt (Union[str, List[str]]): The input prompt(s) to encode. + num_images_per_prompt (Optional[int], defaults to 1): Number of images to generate per prompt. + device_ids (List[int], optional): List of device IDs to use for inference. 
+ + Returns: + - prompt_embd_text_encoder: The prompt embeddings from the text encoder. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + embed_dim = 768 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + if self.text_encoder.qpc_session is None: + self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder_compile_path), device_ids=device_ids) + + text_encoder_output = { + "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32), + "last_hidden_state": np.random.rand(batch_size, self.tokenizer_max_length, embed_dim).astype(np.int32), + } + + self.text_encoder.qpc_session.set_buffers(text_encoder_output) + + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input) + aic_text_encoder_emb = aic_embeddings["pooler_output"] + + + # # # [TEMP] CHECK ACC # # + # prompt_embeds_pytorch = self.text_encoder.model(text_input_ids, output_hidden_states=False) + # pt_pooled_embed = prompt_embeds_pytorch["pooler_output"].detach().numpy() + # mad = np.mean(np.abs(pt_pooled_embed - aic_text_encoder_emb)) + # print(f">>>>>>>>>>>> CLIP text encoder pooled embed MAD: ", mad) ## 0.0043082903 ##TODO : Clean up + ### END CHECK ACC ### + + # Use pooled output of CLIPTextModel + prompt_embeds = torch.tensor(aic_embeddings["pooler_output"]) + # prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + device_ids_text_encoder_1 : Optional[List[int]] = None, + device_ids_text_encoder_2 : Optional[List[int]] = None + ): + r""" + Encode the given prompts into text embeddings using the two text encoders (CLIP and T5). + + This method processes prompts through multiple text encoders to generate embeddings suitable + for Flux pipeline. It handles both positive and negative prompts for + classifier-free guidance. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + device_ids_text_encoder_1 (List[int], optional): List of device IDs to use for CLIP instance . + device_ids_text_encoder_2 (List[int], optional): List of device IDs to use for T5 . + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device_ids=device_ids_text_encoder_1, + num_images_per_prompt=num_images_per_prompt, + + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device_ids=device_ids_text_encoder_2, + ) + + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + return prompt_embeds, pooled_prompt_embeds, text_ids + + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + device_ids_text_encoder_1 : Optional[List[int]] = None, + device_ids_text_encoder_2 : Optional[List[int]] = None, + device_ids_transformer : Optional[List[int]] = None, + device_ids_vae_decoder : Optional[List[int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + device_ids_text_encoder1 (List[int], optional): List of device IDs to use for CLIP instance. + device_ids_text_encoder2 (List[int], optional): List of device IDs to use for T5. + device_ids_transformer (List[int], optional): List of device IDs to use for Flux transformer. + device_ids_vae_decoder (List[int], optional): List of device IDs to use for VAE decoder. + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + + Examples: + ```python + # Basic text-to-image generation + from QEfficient import QEFFFluxPipeline + pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1) + + generator = torch.manual_seed(42) + # NOTE: guidance_scale <=1 is not supported + image = pipeline("A cat holding a sign that says hello world", + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=generator).images[0] + image.save("flux-schnell_aic.png") + ``` + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + device = 'cpu' + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device_ids_text_encoder_1=device_ids_text_encoder_1, + device_ids_text_encoder_2=device_ids_text_encoder_2 + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device_ids_text_encoder_1=device_ids_text_encoder_1, + device_ids_text_encoder_2=device_ids_text_encoder_2 + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.model.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + ###### AIC related changes of transformers ###### + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession(str(self.trasformer_compile_path), device_ids=device_ids_transformer) + + output_buffer = { + "output": np.random.rand( + batch_size, self.transformer.model.config.joint_attention_dim , self.transformer.model.config.in_channels + ).astype(np.int32), + } + + self.transformer.qpc_session.set_buffers(output_buffer) + + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds) + + adaln_emb = [] + for i in range(19): + f1 = self.transformer.model.transformer_blocks[i].norm1.linear(self.transformer.model.transformer_blocks[i].norm1.silu(temb)).chunk(6, dim=1) + f2 = self.transformer.model.transformer_blocks[i].norm1_context.linear(self.transformer.model.transformer_blocks[i].norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + + adaln_dual_emb = torch.stack(adaln_emb) + + adaln_emb = [] + + for i in range(38): + f1 = self.transformer.model.single_transformer_blocks[i].norm.linear(self.transformer.model.single_transformer_blocks[i].norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + + adaln_single_emb = torch.stack(adaln_emb) + + temp = self.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + timestep = timestep / 1000 + + inputs_aic = {"hidden_states": latents.detach().numpy(), + "encoder_hidden_states": prompt_embeds.detach().numpy(), + "pooled_projections": pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy()} + + # noise_pred_torch = self.transformer.model( + # hidden_states=latents, + # encoder_hidden_states = prompt_embeds, + # pooled_projections=pooled_prompt_embeds, + # timestep=torch.tensor(timestep), + # img_ids = latent_image_ids, + # txt_ids = text_ids, + # adaln_emb = adaln_dual_emb, + # adaln_single_emb=adaln_single_emb, + # adaln_out = adaln_out, + # return_dict=False, + # )[0] + + start_time = time.time() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_time = time.time() + print(f"Time : {end_time - start_time:.2f} seconds") + + noise_pred = torch.from_numpy(outputs["output"]) + + # # # ###### ACCURACY TESTING ####### + # mad=np.mean(np.abs(noise_pred_torch.detach().numpy()-outputs['output'])) + # print(f">>>>>>>>> at t = {t} FLUX transfromer model MAD:{mad}") + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor + + if self.vae_decode.qpc_session is None: + self.vae_decode.qpc_session = QAICInferenceSession(str(self.vae_decoder_compile_path), device_ids=device_ids_vae_decoder) + + output_buffer = { + "sample": np.random.rand( + batch_size, 3, self.vae_decode.model.config.sample_size, self.vae_decode.model.config.sample_size + ).astype(np.int32) + } + self.vae_decode.qpc_session.set_buffers(output_buffer) + + inputs = {"latent_sample": latents.numpy()} + image = self.vae_decode.qpc_session.run(inputs) + + ###### ACCURACY TESTING ####### + # image_torch = self.vae_decode.model(latents, return_dict=False)[0] + # mad= np.mean(np.abs(image['sample']-image_torch.detach().numpy())) + # print(">>>>>>>>>>>> VAE mad: ",mad) + + image_tensor = torch.from_numpy(image['sample']) + image = self.image_processor.postprocess(image_tensor, output_type=output_type) + + # Offload all models + # self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py new file mode 100644 index 000000000..b680e148f --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,635 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import hashlib + +import torch +import torch.nn as nn + +from QEfficient.base.modeling_qeff import QEFFBaseModel +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform, NormalizationTransform, CustomOpsTransform +from QEfficient.transformers.models.pytorch_transforms import ( + T5ModelTransform, +) +from QEfficient.utils import constants +from QEfficient.utils.cache import to_hashable + + +class QEffTextEncoder(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform, T5ModelTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + """ + QEffTextEncoder is a wrapper class for text encoder models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle text encoder models (like T5EncoderModel) with specific + transformations and optimizations for efficient inference on Qualcomm AI hardware. 
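+
+    Example (illustrative sketch; ``pipe.text_encoder_2`` is assumed to be the
+    ``T5EncoderModel`` of an already-loaded diffusers pipeline):
+
+    ```python
+    qeff_t5 = QEffTextEncoder(pipe.text_encoder_2)
+    inputs, dynamic_axes, output_names = qeff_t5.get_onnx_config(seq_len=512)
+    onnx_path = qeff_t5.export(inputs, output_names, dynamic_axes)
+    ```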
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = copy.deepcopy(model) + + def get_onnx_config(self, seq_len = 512): + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = seq_len + + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + } + + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}} + output_names = ["pooler_output", "last_hidden_state"] + if self.model.__class__.__name__ == "T5EncoderModel": + output_names = ["last_hidden_state"] + else: + example_inputs["output_hidden_states"] = (True,) + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + seq_len: int, + ): + specializations = [ + {"batch_size": batch_size, "seq_len": seq_len}, + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffClipTextEncoder(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + """ + class QEffClipTextEncoder is a wrapper class for CLIP text encoder models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle clip text encoder models with specific + transformations and optimizations for efficient inference on Qualcomm AI hardware. 
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = copy.deepcopy(model) + + def get_onnx_config(self, seq_len= 77 ): + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + # seq_len = self.tokenizer.model_max_length + + example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "attention_mask": None, + "position_ids": None, + "output_attentions": None + } + example_inputs["output_hidden_states"] = False + + dynamic_axes = { + 'input_ids': {0: 'batch_size', 1: 'sequence_length'}, + 'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}, + 'pooler_output': {0: 'batch_size'} + } + output_names = ["last_hidden_state", "pooler_output"] + + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + seq_len: int, + ): + specializations = [ + {"batch_size": batch_size, "sequence_length": seq_len}, + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.config.__dict__ + + +class QEffUNet(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffUNet is a wrapper class for UNet models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle UNet models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. It is commonly used in diffusion models for image + generation tasks. 
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model.unet) + self.model = model.unet + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(dict(self.model.config))) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffVAE(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffVAE is a wrapper class for Variational Autoencoder (VAE) models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle VAE models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. VAE models are commonly used in diffusion pipelines + for encoding images to latent space and decoding latent representations back to images. 
+ """ + + def __init__(self, model: nn.modules, type: str): + super().__init__(model.vae) + self.model = copy.deepcopy(model.vae) + self.type = type + + def get_onnx_config(self): + # VAE decode + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + example_inputs = { + "latent_sample": torch.randn(bs, 16, 64, 64), + "return_dict": False, + } + + output_names = ["sample"] + + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"}, + } + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + ): + sepcializations = [ + { + "batch_size": batch_size, + "channels": 16, + "height": 128, + "width": 128, + } + ] + return sepcializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(dict(self.model.config))) + mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable(self.type)) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffSafetyChecker(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffSafetyChecker is a wrapper class for safety checker models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle safety checker models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. Safety checker models are commonly used in diffusion pipelines + to filter out potentially harmful or inappropriate generated content. 
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model.vae) + self.model = model.safety_checker + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffFluxTransformerModel(QEFFBaseModel): + _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform ] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + """ + QEffFluxTransformerModel is a wrapper class for Flux Transformer2D models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle Flux Transformer2D models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. It is designed for the newer Flux transformer architecture + that uses transformer-based diffusion models instead of traditional UNet architectures. 
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = model + def get_onnx_config(self, batch_size=1, seq_length = 256): + example_inputs = { + "hidden_states": torch.randn(batch_size, self.model.config.joint_attention_dim, self.model.config.in_channels, dtype=torch.float32), + "encoder_hidden_states": torch.randn(batch_size, seq_length , self.model.config.joint_attention_dim, dtype=torch.float32), + "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32), + "timestep": torch.tensor([1.0], dtype=torch.float32), + "img_ids": torch.randn(self.model.config.joint_attention_dim, 3, dtype=torch.float32), + "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32), + "adaln_emb": torch.randn(self.model.config.num_layers, 12, 3072, dtype=torch.float32), #num_layers, #chunks, # Adalan_hidden_dim + "adaln_single_emb": torch.randn(self.model.config.num_single_layers, 3, 3072, dtype=torch.float32), + "adaln_out": torch.randn(batch_size, 6144, dtype=torch.float32), + } + + output_names = ["output"] + + dynamic_axes = { + "hidden_states": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "steps"}, + # "img_ids": {0: "image_tokens"}, + # "txt_ids": {0: "text_tokens"}, + "adaln_emb": {0: "num_layers"}, + "adaln_single_emb": {0: "num_single_layers"}, + } + + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + seq_len: int, + ): + specializations = [ + { + "batch_size": batch_size, + "stats-batchsize": batch_size, + "num_layers": self.model.config.num_layers, + "num_single_layers": self.model.config.num_single_layers, + "sequence_length": seq_len, + "steps": 1, + } + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + dict_model_config = dict(self.model.config) + dict_model_config.pop("_use_default_values", None) + mhash.update(to_hashable(dict_model_config)) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + +class QEffWanTransformerModel(QEFFBaseModel): + _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform ] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + """ + QEffWanTransformerModel is a wrapper class for WanTransformer 3DModel models that provides ONNX export and compilation capabilities. 
+ + This class extends QEFFBaseModel to handle Wan Transformer3DModel models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. + """ + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = model + + def get_onnx_config(self, batch_size=1, seq_length = 512,cl=6240, latent_height = 30 , latent_width=52): #cl = 3840, # TODO update generic for Wan 2.2 5 B, 14 B + example_inputs = { + "hidden_states": torch.randn(batch_size, self.model.config.out_channels, 16 , latent_height, latent_width ,dtype=torch.float32), #TODO check self.model.config.num_frames #1, 48, 16, 30, 52), + "encoder_hidden_states": torch.randn(batch_size, seq_length , self.model.text_dim, dtype=torch.float32), # BS, seq len , text dim + "rotary_emb": torch.randn(2, cl, 1, 128 , dtype=torch.float32), #TODO update wtih CL + "temb": torch.randn(1, cl, 3072, dtype=torch.float32), + "timestep_proj": torch.randn(1, cl, 6, 3072, dtype=torch.float32), + } + + output_names = ["output"] + + dynamic_axes={ + "hidden_states": { + 0: "batch_size", + 1: "num_channels", + 2: "num_frames", + 3: "latent_height", + 4: "latent_width", + }, + "timestep": {0: "steps"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + # "rotary_emb": {1: "cl"} + }, + + return example_inputs, dynamic_axes, output_names + + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + seq_len: int, + ): + specializations = [ + { + "batch_size": batch_size, + "num_channels": "48", # TODO update with self.model + "num_frames": "16", + "latent_height": "30", + "latent_width": "52", + "sequence_length": seq_len, + "steps": 1, + } + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(dict(self.model.config))) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname diff --git a/QEfficient/diffusers/pipelines/wan/__init__.py b/QEfficient/diffusers/pipelines/wan/__init__.py new file mode 100644 index 000000000..2bc49f7ea --- /dev/null +++ b/QEfficient/diffusers/pipelines/wan/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
new file mode 100644
index 000000000..5237bb835
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
@@ -0,0 +1,760 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import logging
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import regex as re
+
+from diffusers import WanPipeline
+from diffusers.pipelines.wan.pipeline_wan import prompt_clean
+from diffusers.video_processor import VideoProcessor
+from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from QEfficient.diffusers.pipelines.pipeline_utils import QEffTextEncoder, QEffClipTextEncoder, QEffVAE, QEffWanTransformerModel
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils import constants
+
+logger = logging.getLogger(__name__)
+
+
+class QEFFWanPipeline(WanPipeline):
+    _hf_auto_class = WanPipeline
+
+    r"""
+    A QEfficient-optimized Wan pipeline for text-to-video generation, inheriting from
+    `diffusers.WanPipeline`.
+
+    This class integrates QEfficient-wrapped components (the UMT5 text encoder, the Wan
+    transformer, and the VAE) to enhance performance, particularly for deployment on
+    Qualcomm AI hardware. It provides methods for text-to-video generation leveraging
+    these optimized components.
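+
+    Example (illustrative sketch; the checkpoint path and the device/core counts are
+    placeholders):
+
+    ```python
+    from QEfficient import QEFFWanPipeline
+
+    pipeline = QEFFWanPipeline.from_pretrained("<path-or-hub-id-of-a-Wan2.2-diffusers-checkpoint>")
+    pipeline.export()
+    pipeline.compile(num_devices_transformer=16, num_cores=16)
+    result = pipeline("A cat walking on grass", num_frames=61, num_inference_steps=50)
+    # `result` is a `WanPipelineOutput` when `return_dict=True` (the default)
+    ```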
+ """ + + # model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" + # _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # _optional_components = ["transformer", "transformer_2"] + # self, + # tokenizer: AutoTokenizer, + # text_encoder: UMT5EncoderModel, + # vae: AutoencoderKLWan, + # scheduler: FlowMatchEulerDiscreteScheduler, + # transformer: Optional[WanTransformer3DModel] = None, + # transformer_2: Optional[WanTransformer3DModel] = None, + # boundary_ratio: Optional[float] = None, + # expand_timesteps: bool = False, # Wan2.2 ti2v + + def __init__(self, model, *args, **kwargs): + + # Required by diffusers for serialization and device management + self.model = model + self.args = args + self.kwargs = kwargs + + self.text_encoder = model.text_encoder # UMT5EncoderModel ##TODO : update with UMT5 encoder + self.transformer = QEffWanTransformerModel(model.transformer) + self.vae_decode = QEffVAE(model, "decoder") # TODO check and compile + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.scheduler = model.scheduler + # import pdb; pdb.set_trace() + # super().__init__(tokenizer=self.tokenizer, text_encoder=self.text_encoder, vae=self.vae_decode, scheduler=self.scheduler) # taken everything from parent + + + self.register_modules( + vae=self.vae_decode, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + transformer=self.transformer , + scheduler=self.scheduler, + transformer_2=model.transformer_2, + ) + # import pdb; pdb.set_trace() + boundary_ratio = self.kwargs.get("boundary_ratio", None) + expand_timesteps = self.kwargs.get("expand_timesteps", True) ##TODO : not used this part of code in onboarding + self.register_to_config(boundary_ratio=boundary_ratio) + self.register_to_config(expand_timesteps=expand_timesteps) + self.vae_scale_factor_temporal = self.vae_decode.model.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae_decode.model.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Instantiate a QEFFWanTransformer3DModel from pretrained Diffusers models. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + The path to the pretrained model or its name. + **kwargs (additional keyword arguments): + Additional arguments that can be passed to the underlying `StableDiffusion3Pipeline.from_pretrained` + method. + """ + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + **kwargs, + ) + model.to("cpu") + return cls(model, pretrained_model_name_or_path) + + @property + def components(self): + return { + "text_encoder": self.text_encoder, + "transformer": self.transformer, + "transformer_2": getattr(self, "transformer_2", None), + "vae": self.vae_decode, + "tokenizer": self.tokenizer, + "scheduler": self.scheduler, + } + + + def export(self, export_dir: Optional[str] = None) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + + ``Optional`` Args: + :export_dir (str, optional): The directory path to store ONNX-graph. + + Returns: + :str: Path of the generated ``ONNX`` graph. 
+ """ + + # text_encoder - umt5 ##TODO: update once umt5 modeling is available + # example_inputs_text_encoder, dynamic_axes_text_encoder, output_names_text_encoder = ( + # self.text_encoder.get_onnx_config(seq_len = self.tokenizer.model_max_length) + # ) + # self.text_encoder.export( + # inputs=example_inputs_text_encoder, + # output_names=output_names_text_encoder, + # dynamic_axes=dynamic_axes_text_encoder, + # export_dir=export_dir, + # ) + + # transformers + example_inputs_transformer, dynamic_axes_transformer, output_names_transformer = ( + self.transformer.get_onnx_config() + ) + self.transformer.export( + inputs=example_inputs_transformer, + output_names=output_names_transformer, + dynamic_axes=dynamic_axes_transformer, + export_dir=export_dir, + ) + + # vae + example_inputs_vae, dynamic_axes_vae, output_names_vae = self.vae_decode.get_onnx_config() + self.vae_decoder_onnx_path = self.vae_decode.export( + example_inputs_vae, + output_names_vae, + dynamic_axes_vae, + export_dir=export_dir, + ) + + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: Union[int, List[int]] = 512, + batch_size: int = 1, + num_devices_text_encoder: int = 1, + num_devices_transformer: int = 16, + num_devices_vae_decoder: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware. + + This method takes the ONNX paths of the text encoders, transformer, and VAE decoder, + and compiles them into an optimized format for inference. + + Args: + onnx_path (`str`, *optional*): + The base directory where ONNX files were exported. If None, it assumes the ONNX + paths are already set as attributes (e.g., `self.text_encoder_onnx_path`). + This parameter is currently not fully utilized as individual ONNX paths are derived + from the `export` method. + compile_dir (`str`, *optional*): + The directory path to store the compiled artifacts. If None, a default location + might be used by the underlying compilation process. + seq_len (`Union[int, List[int]]`, *optional*, defaults to 32): + The sequence length(s) to use for compiling the text encoders. Can be a single + integer or a list of integers for multiple sequence lengths. + batch_size (`int`, *optional*, defaults to 1): + The batch size to use for compilation. + num_devices_text_encoder (`int`, *optional*, defaults to 1): + The number of AI devices to deploy the text encoder models on. + num_devices_transformer (`int`, *optional*, defaults to 4): + The number of AI devices to deploy the transformer model on. + num_devices_vae_decoder (`int`, *optional*, defaults to 1): + The number of AI devices to deploy the VAE decoder model on. + num_cores (`int`, *optional*, defaults to 16): + The number of cores to use for compilation. This argument is currently marked + as `FIXME: Make this mandatory arg`. + mxfp6_matmul (`bool`, *optional*, defaults to `False`): + If `True`, enables mixed-precision floating-point 6-bit matrix multiplication + optimization during compilation. + **compiler_options: + Additional keyword arguments to pass to the underlying compiler. + + Returns: + `str`: A message indicating the compilation status or path to compiled artifacts. + (Note: The current implementation might need to return specific paths for each compiled model). 
+ """ + # if any( + # path is None + # for path in [ + # # self.text_encoder.onnx_path, + # self.transformer.onnx_path, + # self.vae_decode.onnx_path, + # ] + # ): + # self.export() + # text_encoder - umt5 + # specializations_text_encoder = self.text_encoder.get_specializations( + # batch_size, self.tokenizer.model_max_length + # ) + + # self.text_encoder_compile_path = self.text_encoder._compile( + # onnx_path, + # compile_dir, + # compile_only=True, + # specializations=specializations_text_encoder, + # convert_to_fp16=True, + # mxfp6_matmul=mxfp6_matmul, + # mdp_ts_num_devices=num_devices_text_encoder, + # aic_num_cores=num_cores, + # **compiler_options, + # ) + + + # transformer + # import pdb; pdb.set_trace() + specializations_transformer = self.transformer.get_specializations(batch_size, seq_len) + compiler_options = {"mos": 1, "mdts-mos":1} + # self.trasformer_compile_path = "/home/vtirumal/wan_onboard/Wan2.2_5B-Diffusers-/qpcs/transformer/qpc_transformer/" + self.trasformer_compile_path = self.transformer._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_transformer, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_transformer, + aic_num_cores=num_cores, + **compiler_options, + ) + + # vae + # specializations_vae = self.vae_decode.get_specializations(batch_size) + # self.vae_decoder_compile_path = self.vae_decode._compile( + # onnx_path, + # compile_dir, + # compile_only=True, + # specializations=specializations_vae, + # convert_to_fp16=True, + # mdp_ts_num_devices=num_devices_vae_decoder, + # ) + + # def _get_t5_prompt_embeds( + # self, + # prompt: Union[str, List[str]] = None, + # num_videos_per_prompt: int = 1, + # max_sequence_length: int = 226, + # device: Optional[torch.device] = None, + # dtype: Optional[torch.dtype] = None, + # ): + # device = device or self._execution_device + # dtype = dtype or self.text_encoder.dtype + + # prompt = [prompt] if isinstance(prompt, str) else prompt + # prompt = [prompt_clean(u) for u in prompt] + # batch_size = len(prompt) + + # text_inputs = self.tokenizer( + # prompt, + # padding="max_length", + # max_length=max_sequence_length, + # truncation=True, + # add_special_tokens=True, + # return_attention_mask=True, + # return_tensors="pt", + # ) + # text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + # seq_lens = mask.gt(0).sum(dim=1).long() + + # prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + # prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + # prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + # prompt_embeds = torch.stack( + # [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + # ) + + # # duplicate text embeddings for each generation per prompt, using mps friendly method + # _, seq_len, _ = prompt_embeds.shape + # prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + # prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # return prompt_embeds + + # def encode_prompt( + # self, + # prompt: Union[str, List[str]], + # negative_prompt: Optional[Union[str, List[str]]] = None, + # do_classifier_free_guidance: bool = True, + # num_videos_per_prompt: int = 1, + # prompt_embeds: Optional[torch.Tensor] = None, + # negative_prompt_embeds: Optional[torch.Tensor] = None, + # max_sequence_length: int = 226, + # device: Optional[torch.device] = None, + # dtype: 
Optional[torch.dtype] = None, + # ): + # r""" + # Encodes the prompt into text encoder hidden states. + + # Args: + # prompt (`str` or `List[str]`, *optional*): + # prompt to be encoded + # negative_prompt (`str` or `List[str]`, *optional*): + # The prompt or prompts not to guide the image generation. If not defined, one has to pass + # `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + # less than `1`). + # do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + # Whether to use classifier free guidance or not. + # num_videos_per_prompt (`int`, *optional*, defaults to 1): + # Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + # prompt_embeds (`torch.Tensor`, *optional*): + # Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + # provided, text embeddings will be generated from `prompt` input argument. + # negative_prompt_embeds (`torch.Tensor`, *optional*): + # Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + # weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + # argument. + # device: (`torch.device`, *optional*): + # torch device + # dtype: (`torch.dtype`, *optional*): + # torch dtype + # """ + # device = device or self._execution_device + + # prompt = [prompt] if isinstance(prompt, str) else prompt + # if prompt is not None: + # batch_size = len(prompt) + # else: + # batch_size = prompt_embeds.shape[0] + + # if prompt_embeds is None: + # prompt_embeds = self._get_t5_prompt_embeds( + # prompt=prompt, + # num_videos_per_prompt=num_videos_per_prompt, + # max_sequence_length=max_sequence_length, + # device=device, + # dtype=dtype, + # ) + + # if do_classifier_free_guidance and negative_prompt_embeds is None: + # negative_prompt = negative_prompt or "" + # negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + # if prompt is not None and type(prompt) is not type(negative_prompt): + # raise TypeError( + # f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + # f" {type(prompt)}." + # ) + # elif batch_size != len(negative_prompt): + # raise ValueError( + # f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + # f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + # " the batch size of `prompt`." 
+ # ) + + # negative_prompt_embeds = self._get_t5_prompt_embeds( + # prompt=negative_prompt, + # num_videos_per_prompt=num_videos_per_prompt, + # max_sequence_length=max_sequence_length, + # device=device, + # dtype=dtype, + # ) + + # return prompt_embeds, negative_prompt_embeds + + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 61, #81 + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + # if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + # callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.model.dtype # if self.transformer is not None else self.transformer_2.dtype update it to self.transformer_2.model.dtype for 14 B + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.model.config.in_channels + # if self.transformer is not None + # else self.transformer_2.model.config.in_channels + ) + + # import pdb; pdb.set_trace() + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + # self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + # 6. Denoising loop + ###### AIC related changes of transformers ###### + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession(str(self.trasformer_compile_path)) #, device_ids=device_ids_transformer) + + output_buffer = { + "output": np.random.rand( + batch_size, 6240, 192 #self.transformer.model.config.joint_attention_dim , self.transformer.model.config.in_channels + ).astype(np.int32), + } + self.transformer.qpc_session.set_buffers(output_buffer) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer.model + current_guidance_scale = guidance_scale + else: + print("NOT expected for wan 5 B ") + # # low-noise stage in wan2.2 + # current_model = self.transformer_2.model + # current_guidance_scale = guidance_scale_2 + + latent_model_input = latents.to(transformer_dtype) + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + batch_size, num_channels, num_frames, height, width = latents.shape # modeling + p_t, p_h, p_w = self.transformer.model.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # patch_states = self.transformer.patch_embedding(latent_model_input) + # import pdb; pdb.set_trace() + rotary_emb = self.transformer.model.rope(latent_model_input) + rotary_emb = torch.cat(rotary_emb, dim=0) + ts_seq_len = None + # ts_seq_len = 
timestep.shape[1] + timestep = timestep.flatten() + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.transformer.model.condition_embedder( + timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len + ) + temb, timestep_proj, encoder_hidden_states_neg, encoder_hidden_states_image = self.transformer.model.condition_embedder( + timestep, negative_prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len + ) + # timestep_proj = timestep_proj.unflatten(2, (6, -1)) # for 5 B rnew_app.py ##TODO: cross check once + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + # import pdb; pdb.set_trace() + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": timestep_proj.detach().numpy() + } + + inputs_aic2 = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states_neg.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": timestep_proj.detach().numpy() + } + + # import pdb; pdb.set_trace() + + # with current_model.cache_context("cond"): + # noise_pred_torch = current_model( + # hidden_states=latent_model_input, + # # timestep=timestep, + # encoder_hidden_states=encoder_hidden_states, + # rotary_emb=rotary_emb, + # temb=temb, + # timestep_proj=timestep_proj, + # attention_kwargs=attention_kwargs, + # return_dict=False, + # )[0] + + start_time = time.time() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_time = time.time() + print(f"Time : {end_time - start_time:.2f} seconds") + + # noise_pred = torch.from_numpy(outputs["output"]) + hidden_states = torch.tensor(outputs["output"]) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + + if self.do_classifier_free_guidance: + # with current_model.cache_context("uncond"): + # noise_uncond_pytorch = current_model( + # hidden_states=latent_model_input, + # timestep=timestep, + # encoder_hidden_states=negative_prompt_embeds, + # attention_kwargs=attention_kwargs, + # return_dict=False, + # )[0] + start_time = time.time() + outputs = self.transformer.qpc_session.run(inputs_aic2) + end_time = time.time() + print(f"Time : {end_time - start_time:.2f} seconds") + + # noise_uncond = torch.from_numpy(outputs["output"]) + hidden_states = torch.tensor(outputs["output"]) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_uncond = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # if callback_on_step_end is not None: + # callback_kwargs = {} + # for k in callback_on_step_end_tensor_inputs: + # callback_kwargs[k] = locals()[k] + # callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + # latents = callback_outputs.pop("latents", latents) + # prompt_embeds = callback_outputs.pop("prompt_embeds", 
prompt_embeds) + # negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # if XLA_AVAILABLE: + # xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae_decode.model.dtype) + latents_mean = ( + torch.tensor(self.vae_decode.model.config.latents_mean) + .view(1, self.vae_decode.model.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae_decode.model.config.latents_std).view(1, self.vae_decode.model.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + # import pdb; pdb.set_trace() + video = self.model.vae.decode(latents, return_dict=False)[0] #TODO: to enable aic with qpc self.vae_decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video.detach()) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) \ No newline at end of file diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 8519d824c..8fe1c0868 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -84,7 +84,7 @@ def __init__( self.binding_index_map = {binding.name: binding.index for binding in self.bindings} # Create and load Program prog_properties = qaicrt.QAicProgramProperties() - prog_properties.SubmitRetryTimeoutMs = 60_000 + prog_properties.SubmitRetryTimeoutMs = 60_00000 if device_ids and len(device_ids) > 1: prog_properties.devMapping = ":".join(map(str, device_ids)) self.program = qaicrt.Program(self.context, None, qpc, prog_properties) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..6719396c0 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -142,6 +142,13 @@ Starcoder2ForCausalLM, Starcoder2Model, ) +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerCrossAttention, + T5LayerFF, + T5LayerNorm, + T5LayerSelfAttention, +) from transformers.models.whisper.modeling_whisper import ( WhisperAttention, WhisperDecoder, @@ -309,6 +316,13 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.models.t5.modeling_t5 import ( + QEffT5Attention, + QEffT5LayerCrossAttention, + QEffT5LayerFF, + QEffT5LayerNorm, + QEffT5LayerSelfAttention, +) from QEfficient.transformers.models.whisper.modeling_whisper import ( QEffWhisperAttention, QEffWhisperDecoder, @@ -617,6 +631,22 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} +class T5ModelTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + T5LayerFF: QEffT5LayerFF, + T5LayerSelfAttention: QEffT5LayerSelfAttention, + T5LayerCrossAttention: QEffT5LayerCrossAttention, + T5Attention: QEffT5Attention, + T5LayerNorm: QEffT5LayerNorm, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + + class PoolingTransform: """ Apply a pooling transformation to the model. 
This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py new file mode 100644 index 000000000..9ba5869d7 --- /dev/null +++ b/QEfficient/transformers/models/t5/modeling_t5.py @@ -0,0 +1,217 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerCrossAttention, + T5LayerFF, + T5LayerNorm, + T5LayerSelfAttention, +) + + +class QEffT5LayerNorm(T5LayerNorm): + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class QEffT5LayerFF(T5LayerFF): + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states * 1.0 + self.dropout(forwarded_states) + return hidden_states + + +class QEffT5Attention(T5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + # Original line: position_bias = position_bias[:, :, -seq_length:, :] + if past_key_value is not None: # This block is where the patch applies + # position_bias = position_bias[:, :, -hidden_states.size(1) :, :] # Original line (commented in patch) + position_bias = position_bias[:, :, -1:, :] # Added by patch + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = 
nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class QEffT5LayerSelfAttention(T5LayerSelfAttention): + def __qeff_init__(self): + self.scaling_factor = 1.0 + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states * 1.0 + self.dropout(attention_output[0]) # Modified by patch + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class QEffT5LayerCrossAttention(T5LayerCrossAttention): + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states * 1.0 + self.dropout(attention_output[0]) # Modified by patch + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 50f36ea32..e458fe5b2 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -68,7 +68,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] @@ -103,6 +103,35 @@ def get_models_dir(): GEMMA3_MAX_POSITION_EMBEDDINGS = 32768 +# wo_sfs: weight output scaling factors (used to normalize T5 encoder output weights before export) +WO_SFS = [ + 61, + 203, + 398, + 615, + 845, + 1190, + 1402, + 2242, + 1875, + 2393, + 3845, + 3213, + 3922, + 4429, + 5020, + 5623, + 6439, + 6206, + 5165, + 4593, + 2802, + 2618, + 1891, + 1419, +] + + class Constants: # Export Constants. 
SEQ_LEN = 32
diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png
new file mode 100644
index 000000000..f3ad34a7a
Binary files /dev/null and b/docs/image/girl_laughing.png differ
diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py
new file mode 100644
index 000000000..c67e86c22
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell.py
@@ -0,0 +1,30 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+import torch
+from QEfficient import QEFFFluxPipeline
+
+pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+######## for single layer
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+# # Update num_layers to reflect the change
+# pipeline.transformer.model.config.num_layers = 1
+# pipeline.transformer.model.config.num_single_layers = 1
+
+pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)
+
+generator = torch.manual_seed(42)
+# NOTE: guidance_scale > 1 is not supported
+image = pipeline("A cat holding a sign that says hello world",
+    guidance_scale=0.0,
+    num_inference_steps=4,
+    max_sequence_length=256,
+    generator=generator, device_ids_text_encoder_1=[40], device_ids_text_encoder_2=[41], device_ids_vae_decoder=[42], device_ids_transformer=[44,45,46,47]).images[0]
+image.save("flux-schnell_aic.png")
diff --git a/examples/diffusers/wan/wan2_2_5B.py b/examples/diffusers/wan/wan2_2_5B.py
new file mode 100644
index 000000000..1871e8388
--- /dev/null
+++ b/examples/diffusers/wan/wan2_2_5B.py
@@ -0,0 +1,34 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+import torch
+from diffusers import AutoencoderKLWan
+from diffusers.utils import export_to_video
+from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+from QEfficient import QEFFWanPipeline
+model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipeline = QEFFWanPipeline.from_pretrained(model_id, vae=vae)
+
+# ######## for single layer
+# original_blocks = pipeline.transformer.model.blocks
+# pipeline.transformer.model.blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.config.num_layers = 1
+pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)
+flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+pipeline.to("cpu")
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+output = pipeline(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    height=480,
+    width=832,
+    num_frames=61,
+    guidance_scale=3.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=12)
diff --git a/output.mp4 b/output.mp4
new file mode 100644
index 000000000..9cd551d5c
Binary files /dev/null and b/output.mp4 differ
diff --git a/pyproject.toml b/pyproject.toml
index 479736c22..73136dcef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,9 +20,9 @@ classifiers = [
 requires-python = ">=3.8,<3.11"
 dependencies = [
     "transformers==4.51.3",
-    "huggingface-hub==0.30.0",
+    "huggingface-hub==0.34.0",
     "hf_transfer==0.1.9",
-    "peft==0.13.2",
+    "peft==0.17.0",
     "datasets==2.20.0",
     "fsspec==2023.6.0",
     "multidict==6.0.4",
@@ -39,18 +39,18 @@ dependencies = [
     "fire",
     "py7zr",
     "torchmetrics==1.7.0",
-    "torch==2.4.1; platform_machine=='aarch64'",
+    "torch==2.7.1; platform_machine=='aarch64'",
     # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
     "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
-    "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
-    "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
+    "torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
+    "torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
 ]
 
 [project.optional-dependencies]
 test = ["pytest","pytest-mock"]
 docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"]
 quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"]
-
+diffusers = ["diffusers==0.35.1"]
 [build-system]
 requires = ["setuptools>=62.0.0"]
 build-backend = "setuptools.build_meta"
@@ -71,4 +71,4 @@ target-version = "py310"
 [tool.pytest.ini_options]
 addopts = "-W ignore -s -v"
 junit_logging = "all"
-doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
+doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
\ No newline at end of file
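The Wan pipeline's `__call__` above quietly rounds `num_frames` so that `num_frames - 1` is divisible by the temporal VAE scale factor; the example default of 61 already satisfies this when the factor is 4. A minimal sketch of that rounding rule, with the factor of 4 assumed for illustration (the real value comes from the pipeline's `vae_scale_factor_temporal`):

```python
def round_num_frames(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
    """Mirror of the pipeline's adjustment: force num_frames into the form k * factor + 1."""
    if num_frames % vae_scale_factor_temporal != 1:
        num_frames = num_frames // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
    return max(num_frames, 1)


print(round_num_frames(61))  # 61 (already k * 4 + 1)
print(round_num_frames(80))  # 81
print(round_num_frames(81))  # 81
```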
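The denoising loop above runs the patchified Wan transformer as a QPC via `QAICInferenceSession` and then undoes the patch embedding on the host before applying classifier-free guidance and the scheduler step. The sketch below reproduces just that host-side post-processing in isolation; the reshape/permute order follows the loop above, while the concrete latent sizes, patch size, and channel count are assumptions chosen only for illustration.

```python
# Minimal sketch (not the pipeline's actual code): undo patchification of the
# transformer output and combine conditional/unconditional predictions.
import torch

batch, channels, frames, height, width = 1, 48, 16, 60, 104  # latent-space sizes (assumed)
p_t, p_h, p_w = 1, 2, 2                                        # transformer patch_size (assumed)

post_f, post_h, post_w = frames // p_t, height // p_h, width // p_w
seq_len = post_f * post_h * post_w


def unpatchify(flat: torch.Tensor) -> torch.Tensor:
    """(batch, seq_len, p_t*p_h*p_w*channels) -> (batch, channels, frames, height, width)."""
    x = flat.reshape(batch, post_f, post_h, post_w, p_t, p_h, p_w, -1)
    x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)  # channels to the front, interleave patch dims
    return x.flatten(6, 7).flatten(4, 5).flatten(2, 3)


cond = unpatchify(torch.randn(batch, seq_len, p_t * p_h * p_w * channels))
uncond = unpatchify(torch.randn(batch, seq_len, p_t * p_h * p_w * channels))

guidance_scale = 3.0
noise_pred = uncond + guidance_scale * (cond - uncond)  # classifier-free guidance
assert noise_pred.shape == (batch, channels, frames, height, width)
```

`noise_pred` is what the loop then hands to `self.scheduler.step(...)` to obtain the latents for the next timestep.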
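Before VAE decoding, the loop maps the final latents back into the VAE's latent distribution using the per-channel `latents_mean` and `latents_std` stored in the VAE config. Note that the variable the loop calls `latents_std` actually holds the reciprocal of the configured std, so dividing by it multiplies by the true std. A compact sketch of that denormalization, with a toy channel count assumed for illustration:

```python
import torch

z_dim = 4  # toy latent channel count; the real value comes from the VAE config
latents = torch.randn(1, z_dim, 16, 60, 104)
latents_mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1)
latents_std = torch.rand(z_dim).add(0.5).view(1, z_dim, 1, 1, 1)  # strictly positive

# The loop stores the reciprocal std, so dividing by it multiplies by the true std.
inv_std = 1.0 / latents_std
denormalized = latents / inv_std + latents_mean  # == latents * latents_std + latents_mean
assert torch.allclose(denormalized, latents * latents_std + latents_mean, atol=1e-6)
```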
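`QEffT5LayerNorm` above computes the same RMSNorm as the stock `T5LayerNorm`, but scales the activations by `1/sqrt(hidden_size)` before squaring so that the accumulation stays in range for half-precision export: the sum of the scaled squares equals the mean of the raw squares. A self-contained equivalence check of the two formulations (plain functions, not the QEfficient classes):

```python
import torch


def t5_rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Stock T5LayerNorm: variance is the mean of squares, no mean subtraction, no bias.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (hidden_states * torch.rsqrt(variance + eps))


def t5_rmsnorm_scaled(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Rewritten form: divide by sqrt(d) first, then the sum of squares equals the mean of squares.
    d = torch.tensor(hidden_states.shape[-1], dtype=torch.float32)
    div_first = hidden_states * torch.rsqrt(d)
    variance = div_first.pow(2).sum(-1, keepdim=True)
    return weight * (hidden_states * torch.rsqrt(variance + eps))


x = torch.randn(2, 7, 4096)
w = torch.ones(4096)
assert torch.allclose(t5_rmsnorm_reference(x, w), t5_rmsnorm_scaled(x, w), atol=1e-5)
```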