diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 33c6f5588..b4e7d9f1b 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -50,6 +50,9 @@ def check_qaic_sdk():
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
+
+ # Imports for the diffusers
+ from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -70,6 +73,7 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
+ "QEFFFluxPipeline",
]
else:
diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 6ecbf0fc0..23db61472 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -117,8 +117,30 @@ def _model_offloaded_check(self) -> None:
raise RuntimeError(error_msg)
@property
- @abstractmethod
- def model_name(self) -> str: ...
+ def model_name(self) -> str:
+ """
+ Get the model class name without QEff/QEFF prefix.
+
+ This property extracts the underlying model's class name and removes
+ any QEff or QEFF prefix that may have been added during wrapping.
+
+ Returns:
+ str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
+ """
+ mname = self.model.__class__.__name__
+ if mname.startswith("QEff") or mname.startswith("QEFF"):
+ mname = mname[4:]
+ return mname
+
+ @property
+ def get_model_config(self) -> Dict:
+ """
+ Get the model configuration as a dictionary.
+
+ Returns:
+ Dict: The configuration dictionary of the underlying HuggingFace model
+ """
+ return self.model.config.__dict__
@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md
new file mode 100644
index 000000000..a42cc4bdf
--- /dev/null
+++ b/QEfficient/diffusers/README.md
@@ -0,0 +1,108 @@
+# **Diffusion Models on Qualcomm Cloud AI 100**
+
+### 🎨 **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale • ⚡
+
+Built on [🤗 HuggingFace Diffusers](https://github.com/huggingface/diffusers).
+
+---
+
+## ✨ Overview
+
+QEfficient Diffusers brings state-of-the-art text-to-image diffusion models to Qualcomm Cloud AI 100 hardware. Built on top of the popular HuggingFace Diffusers library, the optimized pipeline provides seamless on-device inference.
+
+## 🛠️ Installation
+
+### Prerequisites
+
+Ensure you have Python 3.8 or later (Python 3.10 is recommended) and the required dependencies:
+
+```bash
+# Create Python virtual environment (Recommended Python 3.10)
+sudo apt install python3.10-venv
+python3.10 -m venv qeff_env
+source qeff_env/bin/activate
+pip install -U pip
+```
+
+### Install QEfficient
+
+```bash
+# Install from GitHub (includes diffusers support)
+pip install git+https://github.com/quic/efficient-transformers
+
+# Or build from source
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install build wheel
+python -m build --wheel --outdir dist
+pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
+```
+
+### Install Diffusers Dependencies
+
+```bash
+# Install diffusers optional dependencies
+pip install "QEfficient[diffusers]"
+```
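+
+As a quick sanity check on a machine with the Cloud AI 100 SDK installed, the new pipeline class should be importable (this verifies the install only, not compilation):
+
+```python
+# Quick import check for diffusers support
+from QEfficient import QEFFFluxPipeline  # noqa: F401
+```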
+
+---
+
+## 🎯 Supported Models
+- ✅ [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
+
+---
+
+
+## 📚 Examples
+
+Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory:
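+
+Below is a minimal end-to-end sketch of the API added by this package; the model ID, prompt, and step count are illustrative, and the scripts in `examples/diffusers/` are the tested reference:
+
+```python
+from QEfficient import QEFFFluxPipeline
+
+# Load FLUX and wrap its text encoders, transformer, and VAE decoder with
+# QEfficient-optimized modules (runs on CPU in float32 until compiled).
+pipeline = QEFFFluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", height=512, width=512
+)
+
+# Export each module to ONNX and compile QPCs for Cloud AI 100
+# (uses the bundled flux_config.json unless a custom config is passed).
+pipeline.compile()
+
+# Generate an image; FLUX.1-schnell is typically run with few steps and no guidance.
+output = pipeline("A girl laughing", num_inference_steps=4, guidance_scale=0.0)
+```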
+
+---
+
+## 🤝 Contributing
+
+We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.
+
+### Development Setup
+
+```bash
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install -e ".[diffusers,test]"
+```
+
+---
+
+## 🙏 Acknowledgments
+
+- **HuggingFace Diffusers**: For the excellent foundation library
+- **Black Forest Labs**: For the FLUX family of models
+
+---
+
+## 📞 Support
+
+- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
+- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)
+
+---
+
diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/attention.py b/QEfficient/diffusers/models/attention.py
new file mode 100644
index 000000000..3c9cc268d
--- /dev/null
+++ b/QEfficient/diffusers/models/attention.py
@@ -0,0 +1,75 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import torch
+from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward
+
+
+class QEffJointTransformerBlock(JointTransformerBlock):
+ def forward(
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+ ):
+ if self.use_dual_attention:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+ hidden_states, emb=temb
+ )
+ else:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ if self.context_pre_only:
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+ else:
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # Attention.
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ )
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ if self.use_dual_attention:
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+ hidden_states = hidden_states + attn_output2
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ # ff_output = self.ff(norm_hidden_states)
+ ff_output = self.ff(norm_hidden_states, block_size=4096)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ if self.context_pre_only:
+ encoder_hidden_states = None
+ else:
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ context_ff_output = _chunked_feed_forward(
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+ )
+ else:
+ # context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return encoder_hidden_states, hidden_states
diff --git a/QEfficient/diffusers/models/attention_processor.py b/QEfficient/diffusers/models/attention_processor.py
new file mode 100644
index 000000000..01954e55e
--- /dev/null
+++ b/QEfficient/diffusers/models/attention_processor.py
@@ -0,0 +1,155 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+from typing import Optional
+
+import torch
+from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+
+
+class QEffAttention(Attention):
+ def __qeff_init__(self):
+ processor = QEffJointAttnProcessor2_0()
+ self.processor = processor
+ processor.query_block_size = 64
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key,
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+
+class QEffJointAttnProcessor2_0(JointAttnProcessor2_0):
+ def __call__(
+ self,
+ attn: QEffAttention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
+ query = query.reshape(-1, query.shape[-2], query.shape[-1])
+ key = key.reshape(-1, key.shape[-2], key.shape[-1])
+ value = value.reshape(-1, value.shape[-2], value.shape[-1])
+
+ # pre-transpose the key
+ key = key.transpose(-1, -2)
+ if query.size(-2) != value.size(-2): # cross-attention, use regular attention
+ # QKV done in single block
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ else: # self-attention, use blocked attention
+ # QKV done with block-attention (a la FlashAttentionV2)
+ query_block_size = self.query_block_size
+ query_seq_len = query.size(-2)
+ num_blocks = (query_seq_len + query_block_size - 1) // query_block_size
+ for qidx in range(num_blocks):
+ query_block = query[:, qidx * query_block_size : (qidx + 1) * query_block_size, :]
+ attention_probs = attn.get_attention_scores(query_block, key, attention_mask)
+ hidden_states_block = torch.bmm(attention_probs, value)
+ if qidx == 0:
+ hidden_states = hidden_states_block
+ else:
+ hidden_states = torch.cat((hidden_states, hidden_states_block), -2)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if encoder_hidden_states is not None:
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
diff --git a/QEfficient/diffusers/models/autoencoders/__init__.py b/QEfficient/diffusers/models/autoencoders/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py
new file mode 100644
index 000000000..c652f07d2
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py
@@ -0,0 +1,31 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import torch
+from diffusers import AutoencoderKL
+
+
+class QEffAutoencoderKL(AutoencoderKL):
+ def encode(self, x: torch.Tensor, return_dict: bool = True):
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ return h
diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py
new file mode 100644
index 000000000..87afcf670
--- /dev/null
+++ b/QEfficient/diffusers/models/normalization.py
@@ -0,0 +1,51 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Optional
+
+import torch
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+class QEffAdaLayerNormZero(AdaLayerNormZero):
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ shift_msa: Optional[torch.Tensor] = None,
+ scale_msa: Optional[torch.Tensor] = None,
+ # emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ # if self.emb is not None:
+ # emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+ # emb = self.linear(self.silu(emb))
+ # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_msa: Optional[torch.Tensor] = None,
+ shift_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+ # shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x
+
+
+class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+ # emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
+ emb = conditioning_embedding
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
new file mode 100644
index 000000000..582adfac7
--- /dev/null
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -0,0 +1,92 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+from typing import Tuple
+
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+)
+from torch import nn
+
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.attention import QEffJointTransformerBlock
+from QEfficient.diffusers.models.attention_processor import (
+ QEffAttention,
+ QEffJointAttnProcessor2_0,
+)
+from QEfficient.diffusers.models.normalization import (
+ QEffAdaLayerNormContinuous,
+ QEffAdaLayerNormZero,
+ QEffAdaLayerNormZeroSingle,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+ QEffFluxAttention,
+ QEffFluxAttnProcessor,
+ QEffFluxSingleTransformerBlock,
+ QEffFluxTransformer2DModel,
+ QEffFluxTransformer2DModelOF,
+ QEffFluxTransformerBlock,
+)
+
+
+class CustomOpsTransform(ModuleMappingTransform):
+ _module_mapping = {
+ RMSNorm: CustomRMSNormAIC,
+ nn.RMSNorm: CustomRMSNormAIC, # for torch.nn.RMSNorm
+ }
+
+ @classmethod
+ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+ model, transformed = super().apply(model)
+ return model, transformed
+
+
+class AttentionTransform(ModuleMappingTransform):
+ _module_mapping = {
+ Attention: QEffAttention,
+ JointAttnProcessor2_0: QEffJointAttnProcessor2_0,
+ JointTransformerBlock: QEffJointTransformerBlock,
+ FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
+ FluxTransformerBlock: QEffFluxTransformerBlock,
+ FluxTransformer2DModel: QEffFluxTransformer2DModel,
+ FluxAttention: QEffFluxAttention,
+ FluxAttnProcessor: QEffFluxAttnProcessor,
+ }
+
+ @classmethod
+ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+ model, transformed = super().apply(model)
+ return model, transformed
+
+
+class NormalizationTransform(ModuleMappingTransform):
+ _module_mapping = {
+ AdaLayerNormZero: QEffAdaLayerNormZero,
+ AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
+ AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
+ }
+
+ @classmethod
+ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+ model, transformed = super().apply(model)
+ return model, transformed
+
+
+class OnnxFunctionTransform(ModuleMappingTransform):
+    _module_mapping = {QEffFluxTransformer2DModel: QEffFluxTransformer2DModelOF}
+
+ @classmethod
+ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+ model, transformed = super().apply(model)
+ return model, transformed
diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py
new file mode 100644
index 000000000..8a9635b13
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/transformer_flux.py
@@ -0,0 +1,425 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+ FluxAttention,
+ FluxAttnProcessor,
+ FluxSingleTransformerBlock,
+ FluxTransformer2DModel,
+ FluxTransformerBlock,
+ _get_qkv_projections,
+)
+from diffusers.utils import logging
+
+from QEfficient.diffusers.models.normalization import (
+ QEffAdaLayerNormZero,
+ QEffAdaLayerNormZeroSingle,
+)
+
+logger = logging.get_logger(__name__)
+
+
+def qeff_apply_rotary_emb(
+ x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
+) -> torch.Tensor:
+ """
+    Apply rotary positional embeddings to a query or key tensor using precomputed
+    cosine/sine frequency tensors. The last dimension of `x` is treated as interleaved
+    real/imaginary pairs, rotated, and returned as a real tensor of the same shape as
+    the input.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor of shape [B, S, H, D].
+        freqs_cis (`Tuple[torch.Tensor, torch.Tensor]`):
+            Precomputed cosine and sine frequency tensors, each of shape [S, D].
+
+    Returns:
+        `torch.Tensor`: The input tensor with rotary embeddings applied, same shape as `x`.
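+
+    Example (illustrative shapes only, not tied to a specific FLUX configuration):
+        >>> x = torch.randn(1, 128, 24, 128)                        # [B, S, H, D]
+        >>> cos, sin = torch.ones(128, 128), torch.zeros(128, 128)  # [S, D] each
+        >>> qeff_apply_rotary_emb(x, (cos, sin)).shape
+        torch.Size([1, 128, 24, 128])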
+ """
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+ B, S, H, D = x.shape
+ x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
+class QEffFluxAttnProcessor(FluxAttnProcessor):
+ _attention_backend = None
+ _parallel_config = None
+
+ def __call__(
+ self,
+ attn: "QEffFluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = qeff_apply_rotary_emb(query, image_rotary_emb)
+ key = qeff_apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = dispatch_attention_fn(
+ query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class QEffFluxAttention(FluxAttention):
+ def __qeff_init__(self):
+ processor = QEffFluxAttnProcessor()
+ self.processor = processor
+
+
+class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+ super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio)
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm = QEffAdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+ self.attn = QEffFluxAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=QEffFluxAttnProcessor(),
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
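+        # `temb` is the precomputed AdaLayerNorm modulation for this block, stacked
+        # along dim 0 as (shift_msa, scale_msa, gate).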
+ temb = tuple(torch.split(temb, 1))
+ gate = temb[2]
+ residual = hidden_states
+ norm_hidden_states = self.norm(hidden_states, scale_msa=temb[1], shift_msa=temb[0])
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ # if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformerBlock(FluxTransformerBlock):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__(dim, num_attention_heads, attention_head_dim)
+
+ self.norm1 = QEffAdaLayerNormZero(dim)
+ self.norm1_context = QEffAdaLayerNormZero(dim)
+ self.attn = QEffFluxAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=QEffFluxAttnProcessor(),
+ eps=eps,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
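+        # `temb` carries the precomputed AdaLayerNorm modulation for this block:
+        # the first six chunks modulate the image stream (shift/scale/gate for
+        # attention and MLP), the last six modulate the text (context) stream.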
+ temb1 = tuple(torch.split(temb[:6], 1))
+ temb2 = tuple(torch.split(temb[6:], 1))
+ norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1])
+ gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:]
+
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1])
+
+ c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:]
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ # if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class QEffFluxTransformer2DModel(FluxTransformer2DModel):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ adaln_emb: torch.Tensor = None,
+ adaln_single_emb: torch.Tensor = None,
+ adaln_out: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
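+            adaln_emb (`torch.Tensor`, *optional*):
+                Precomputed AdaLayerNorm modulation embeddings, one entry per dual-stream transformer block.
+            adaln_single_emb (`torch.Tensor`, *optional*):
+                Precomputed AdaLayerNorm modulation embeddings, one entry per single-stream transformer block.
+            adaln_out (`torch.Tensor`, *optional*):
+                Precomputed conditioning embedding consumed by the final `norm_out` layer.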
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=adaln_single_emb[index_block],
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
+
+ hidden_states = self.norm_out(hidden_states, adaln_out)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+
+class QEffFluxTransformer2DModelOF(QEffFluxTransformer2DModel):
+ def __qeff_init__(self):
+ self.transformer_blocks = nn.ModuleList()
+ self._block_classes = set()
+
+ for _ in range(self.config.num_layers):
+ BlockClass = QEffFluxTransformerBlock
+ block = BlockClass(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ self.transformer_blocks.append(block)
+ self._block_classes.add(BlockClass)
+
+ self.single_transformer_blocks = nn.ModuleList()
+
+ for _ in range(self.config.num_single_layers):
+ SingleBlockClass = QEffFluxSingleTransformerBlock
+ single_block = SingleBlockClass(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ self.single_transformer_blocks.append(single_block)
+ self._block_classes.add(SingleBlockClass)
diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/pipelines/flux/flux_config.json b/QEfficient/diffusers/pipelines/flux/flux_config.json
new file mode 100644
index 000000000..546528445
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/flux_config.json
@@ -0,0 +1,94 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+
+ },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
new file mode 100644
index 000000000..38df59941
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py
@@ -0,0 +1,731 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import FluxPipeline
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from tqdm import tqdm
+
+from QEfficient.diffusers.pipelines.pipeline_module import (
+ QEffFluxTransformerModel,
+ QEffTextEncoder,
+ QEffVAE,
+)
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+ ModulePerf,
+ QEffPipelineOutput,
+ compile_modules_parallel,
+ compile_modules_sequential,
+ config_manager,
+ set_module_device_ids,
+)
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.utils.logging_utils import logger
+
+
+class QEFFFluxPipeline(FluxPipeline):
+ """
+ QEfficient-optimized Flux pipeline for text-to-image generation on Qualcomm AI hardware.
+
+ Attributes:
+ text_encoder (QEffTextEncoder): Optimized CLIP text encoder
+ text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder
+ transformer (QEffFluxTransformerModel): Optimized Flux transformer
+ vae_decode (QEffVAE): Optimized VAE decoder
+ modules (Dict): Dictionary of all pipeline modules for iteration
+ """
+
+ _hf_auto_class = FluxPipeline
+
+ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
+ """
+ Initialize the QEfficient Flux pipeline.
+
+ Args:
+ model: Pre-loaded FluxPipeline model
+ use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
+ **kwargs: Additional arguments including height and width
+ """
+ # Wrap model components with QEfficient optimized versions
+ self.text_encoder = QEffTextEncoder(model.text_encoder)
+ self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+ self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_function=use_onnx_function)
+ self.vae_decode = QEffVAE(model, "decoder")
+ self.use_onnx_function = use_onnx_function
+
+ # Store all modules in a dictionary for easy iteration during export/compile
+ self.modules = {
+ "text_encoder": self.text_encoder,
+ "text_encoder_2": self.text_encoder_2,
+ "transformer": self.transformer,
+ "vae_decoder": self.vae_decode,
+ }
+
+ # Copy tokenizers and scheduler from the original model
+ self.tokenizer = model.tokenizer
+ self.text_encoder.tokenizer = model.tokenizer
+ self.text_encoder_2.tokenizer = model.tokenizer_2
+ self.tokenizer_max_length = model.tokenizer_max_length
+ self.scheduler = model.scheduler
+
+ # Set default image dimensions
+ self.height = kwargs.get("height", 256)
+ self.width = kwargs.get("width", 256)
+
+ # Override VAE forward method to use decode directly
+ self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
+ latent_sample, return_dict
+ )
+
+ # Calculate VAE scale factor from model config
+ self.vae_scale_factor = (
+ 2 ** (len(model.vae.config.block_out_channels) - 1) if getattr(model, "vae", None) else 8
+ )
+
+ # Flux uses 2x2 patches, so multiply scale factor by patch size
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ # Set tokenizer max length with fallback
+ self.t_max_length = (
+ model.tokenizer.model_max_length if hasattr(model, "tokenizer") and model.tokenizer is not None else 77
+ )
+
+ # Calculate latent dimensions based on image size and VAE scale factor
+ self.default_sample_size = 128
+ self.latent_height = self.height // self.vae_scale_factor
+ self.latent_width = self.width // self.vae_scale_factor
+ # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
+ self.cl = (self.latent_height * self.latent_width) // 4
+
+ # Sync max position embeddings between text encoders
+ self.text_encoder_2.model.config.max_position_embeddings = (
+ self.text_encoder.model.config.max_position_embeddings
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ use_onnx_function: bool = False,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ **kwargs,
+ ):
+ """
+ Load a pretrained Flux model and wrap it with QEfficient optimizations.
+
+ Args:
+ pretrained_model_name_or_path (str or os.PathLike): HuggingFace model ID or local path
+ use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
+ height (int): Target image height (default: 512)
+ width (int): Target image width (default: 512)
+ **kwargs: Additional arguments passed to FluxPipeline.from_pretrained
+
+ Returns:
+ QEFFFluxPipeline: Initialized pipeline instance
+ """
+ # Load the base Flux model in float32 on CPU
+ model = cls._hf_auto_class.from_pretrained(
+ pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ **kwargs,
+ )
+ model.to("cpu")
+
+ return cls(
+ model=model,
+ use_onnx_function=use_onnx_function,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ height=height,
+ width=width,
+ **kwargs,
+ )
+
+ def export(self, export_dir: Optional[str] = None) -> str:
+ """
+ Export all pipeline modules to ONNX format.
+
+ This method iterates through all modules (text encoders, transformer, VAE decoder)
+ and exports each to ONNX using their respective configurations.
+
+ Args:
+ export_dir (str, optional): Directory to save ONNX models. If None, uses default.
+
+ Returns:
+ str: Path to the export directory
+ """
+ for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
+ # Get ONNX export configuration for this module
+ example_inputs, dynamic_axes, output_names = module_obj.get_onnx_config()
+
+ export_kwargs = {}
+ # Special handling for transformer: export blocks as functions if enabled
+ if module_name == "transformer" and self.use_onnx_function:
+ export_kwargs = {
+ "export_modules_as_functions": self.transformer.model._block_classes,
+ }
+
+ # Export the module to ONNX
+ module_obj.export(
+ inputs=example_inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ @staticmethod
+ def get_default_config_path() -> str:
+ """
+ Get the path to the default Flux pipeline configuration file.
+
+ Returns:
+ str: Absolute path to flux_config.json
+ """
+ return os.path.join(os.path.dirname(__file__), "flux_config.json")
+
+ def compile(self, compile_config: Optional[str] = None, parallel: bool = False) -> None:
+ """
+ Compile ONNX models for deployment on Qualcomm AI hardware.
+
+ This method compiles all pipeline modules (text encoders, transformer, VAE decoder)
+ into optimized QPC (Qualcomm Program Container) format for inference on QAIC devices.
+
+ Args:
+ compile_config (str, optional): Path to JSON configuration file.
+ If None, uses default configuration.
+ parallel (bool): If True, compile modules in parallel using ProcessPoolExecutor.
+ If False, compile sequentially (default: False).
+ """
+ # Ensure all modules are exported to ONNX before compilation
+ if any(
+ path is None
+ for path in [
+ self.text_encoder.onnx_path,
+ self.text_encoder_2.onnx_path,
+ self.transformer.onnx_path,
+ self.vae_decode.onnx_path,
+ ]
+ ):
+ self.export()
+
+ # Load compilation configuration
+ if self.custom_config is None:
+ config_manager(self, config_source=compile_config)
+
+ # Prepare dynamic specialization updates based on image dimensions
+ specialization_updates = {
+ "transformer": {"cl": self.cl},
+ "vae_decoder": {
+ "latent_height": self.latent_height,
+ "latent_width": self.latent_width,
+ },
+ }
+
+ # Use generic utility functions for compilation
+ if parallel:
+ compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
+ else:
+ compile_modules_sequential(self.modules, self.custom_config, specialization_updates)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode prompts using the T5 text encoder.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ max_sequence_length (int): Maximum token sequence length (default: 512)
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (prompt_embeds, inference_time)
+ - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096]
+ - inference_time (float): T5 encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ embed_dim = 4096 # T5 embedding dimension
+
+ # Tokenize prompts with padding and truncation
+ text_inputs = self.text_encoder_2.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.text_encoder_2.tokenizer.batch_decode(
+ untruncated_ids[:, self.text_encoder_2.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ f"The following part of your input was truncated because `max_sequence_length` is set to "
+ f"{self.text_encoder_2.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder_2.qpc_session is None:
+ self.text_encoder_2.qpc_session = QAICInferenceSession(
+ str(self.text_encoder_2.qpc_path), device_ids=device_ids
+ )
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_2_output = {
+ "last_hidden_state": np.random.rand(batch_size, max_sequence_length, embed_dim).astype(np.float32),
+ }
+ self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run T5 encoder inference and measure time
+ start_t5_time = time.time()
+ prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+ end_t5_time = time.time()
+ text_encoder_2_perf = end_t5_time - start_t5_time
+
+ # Duplicate embeddings for multiple images per prompt
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, text_encoder_2_perf
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device_ids: Optional[List[int]] = None,
+ ):
+ """
+ Encode prompts using the CLIP text encoder.
+
+ Args:
+ prompt (str or List[str]): Input prompt(s) to encode
+ num_images_per_prompt (int): Number of images to generate per prompt
+ device_ids (List[int], optional): QAIC device IDs for inference
+
+ Returns:
+ tuple: (pooled_prompt_embeds, inference_time)
+ - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768]
+ - inference_time (float): CLIP encoder inference time in seconds
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ embed_dim = 768 # CLIP embedding dimension
+
+ # Tokenize prompts
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+
+ # Check for truncation and warn user
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ f"The following part of your input was truncated because CLIP can only handle sequences up to "
+ f"{self.tokenizer_max_length} tokens: {removed_text}"
+ )
+
+ # Initialize QAIC inference session if not already created
+ if self.text_encoder.qpc_session is None:
+ self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
+
+ # Allocate output buffers for QAIC inference
+ text_encoder_output = {
+ "last_hidden_state": np.random.rand(batch_size, self.tokenizer_max_length, embed_dim).astype(np.float32),
+ "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.float32),
+ }
+ self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+ # Prepare input for QAIC inference
+ aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+ # Run CLIP encoder inference and measure time
+ start_text_encoder_time = time.time()
+ aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+ end_text_encoder_time = time.time()
+ text_encoder_perf = end_text_encoder_time - start_text_encoder_time
+
+ # Extract pooled output (used for conditioning in Flux)
+ prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
+
+ # Duplicate embeddings for multiple images per prompt
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, text_encoder_perf
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ """
+ Encode prompts using both CLIP and T5 text encoders.
+
+ Flux uses a dual text encoder setup:
+ - CLIP provides pooled embeddings for global conditioning
+ - T5 provides sequence embeddings for detailed text understanding
+
+ Args:
+ prompt (str or List[str]): Primary prompt(s)
+ prompt_2 (str or List[str], optional): Secondary prompt(s) for T5. If None, uses primary prompt
+ num_images_per_prompt (int): Number of images to generate per prompt
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings
+ max_sequence_length (int): Maximum sequence length for T5 tokenization
+
+ Returns:
+ tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times)
+ - prompt_embeds: T5 sequence embeddings
+ - pooled_prompt_embeds: CLIP pooled embeddings
+ - text_ids: Position IDs for text tokens
+ - encoder_perf_times: List of [CLIP_time, T5_time]
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ # Use primary prompt for both encoders if secondary not provided
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # Encode with CLIP (returns pooled embeddings)
+ pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device_ids=self.text_encoder.device_ids,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ # Encode with T5 (returns sequence embeddings)
+ prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device_ids=self.text_encoder_2.device_ids,
+ )
+
+ # Create text position IDs (required by Flux transformer)
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf]
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ custom_config_path: Optional[str] = None,
+ parallel_compile: bool = False,
+ ):
+ """
+ Generate images from text prompts using the Flux pipeline.
+
+ This is the main entry point for image generation. It orchestrates the entire pipeline:
+ 1. Validates inputs and loads configuration
+ 2. Encodes prompts using CLIP and T5
+ 3. Prepares latents and timesteps
+ 4. Runs denoising loop with transformer
+ 5. Decodes latents to images with VAE
+
+ Args:
+ prompt (str or List[str]): Text prompt(s) for image generation
+ prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder
+ negative_prompt (str or List[str], optional): Negative prompt for classifier-free guidance
+ negative_prompt_2 (str or List[str], optional): Secondary negative prompt
+ true_cfg_scale (float): True CFG scale (default: 1.0, disabled)
+ num_inference_steps (int): Number of denoising steps (default: 28)
+ timesteps (List[int], optional): Custom timestep schedule
+ guidance_scale (float): Guidance scale for generation (default: 3.5)
+ num_images_per_prompt (int): Number of images per prompt (default: 1)
+ generator (torch.Generator, optional): Random generator for reproducibility
+ latents (torch.FloatTensor, optional): Pre-generated latents
+ prompt_embeds (torch.FloatTensor, optional): Pre-computed prompt embeddings
+ pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed pooled embeddings
+ negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative embeddings
+ negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative pooled embeddings
+ output_type (str): Output format - "pil", "np", or "latent" (default: "pil")
+ return_dict (bool): Whether to return QEffPipelineOutput object (default: True)
+ joint_attention_kwargs (dict, optional): Additional attention processor kwargs
+ callback_on_step_end (Callable, optional): Callback function after each step
+ callback_on_step_end_tensor_inputs (List[str]): Tensors to pass to callback
+ max_sequence_length (int): Maximum sequence length for T5 (default: 512)
+ custom_config_path (str, optional): Path to custom compilation config
+ parallel_compile (bool): If True, compile modules in parallel for faster compilation.
+ If False, compile sequentially (default: False).
+
+ Returns:
+ QEffPipelineOutput or tuple: Generated images and performance metrics
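+
+        Example (illustrative; the model id, prompt, and settings below are placeholders):
+
+            pipe = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+            out = pipe(prompt="A girl laughing", num_inference_steps=4, guidance_scale=0.0)
+            out.images[0].save("out.png")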
+ """
+ device = "cpu"
+
+ # Step 1: Load configuration and compile models if needed
+ if custom_config_path is not None:
+ config_manager(self, custom_config_path)
+ set_module_device_ids(self)
+
+ self.compile(compile_config=custom_config_path, parallel=parallel_compile)
+
+ # Validate all inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ self.height,
+ self.width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # Step 2: Determine batch size from inputs
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Step 3: Encode prompts with both text encoders
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Encode negative prompts if using true classifier-free guidance
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+                negative_text_ids,
+                _,
+            ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # Step 4: Prepare timesteps for denoising
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Step 5: Prepare initial latents
+ num_channels_latents = self.transformer.model.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ self.height,
+ self.width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # Step 6: Initialize transformer inference session
+ if self.transformer.qpc_session is None:
+ self.transformer.qpc_session = QAICInferenceSession(
+ str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
+ )
+
+ # Allocate output buffer for transformer
+ output_buffer = {
+ "output": np.random.rand(batch_size, self.cl, self.transformer.model.config.in_channels).astype(np.float32),
+ }
+ self.transformer.qpc_session.set_buffers(output_buffer)
+
+ transformer_perf = []
+ self.scheduler.set_begin_index(0)
+
+ # Step 7: Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # Prepare timestep embedding
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds)
+
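+                # Note: AdaLN projections are computed on the host because they depend only on
+                # the timestep and pooled prompt embeddings; the compiled transformer consumes
+                # them as explicit inputs (adaln_emb, adaln_single_emb, adaln_out).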
+ # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.transformer_blocks)):
+ block = self.transformer.model.transformer_blocks[block_idx]
+ # Process through norm1 and norm1_context
+ f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1)
+ f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1)
+ adaln_emb.append(torch.cat(list(f1) + list(f2)))
+ adaln_dual_emb = torch.stack(adaln_emb)
+
+ # Compute AdaLN embeddings for single transformer blocks
+ adaln_emb = []
+ for block_idx in range(len(self.transformer.model.single_transformer_blocks)):
+ block = self.transformer.model.single_transformer_blocks[block_idx]
+ f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1)
+ adaln_emb.append(torch.cat(list(f1)))
+ adaln_single_emb = torch.stack(adaln_emb)
+
+ # Compute output AdaLN embedding
+ temp = self.transformer.model.norm_out
+ adaln_out = temp.linear(temp.silu(temb))
+
+ # Normalize timestep to [0, 1] range
+ timestep = timestep / 1000
+
+ # Prepare all inputs for transformer inference
+ inputs_aic = {
+ "hidden_states": latents.detach().numpy(),
+ "encoder_hidden_states": prompt_embeds.detach().numpy(),
+ "pooled_projections": pooled_prompt_embeds.detach().numpy(),
+ "timestep": timestep.detach().numpy(),
+ "img_ids": latent_image_ids.detach().numpy(),
+ "txt_ids": text_ids.detach().numpy(),
+ "adaln_emb": adaln_dual_emb.detach().numpy(),
+ "adaln_single_emb": adaln_single_emb.detach().numpy(),
+ "adaln_out": adaln_out.detach().numpy(),
+ }
+
+ # Run transformer inference and measure time
+ start_transformer_step_time = time.time()
+ outputs = self.transformer.qpc_session.run(inputs_aic)
+ end_transformer_step_time = time.time()
+ transformer_perf.append(end_transformer_step_time - start_transformer_step_time)
+
+ noise_pred = torch.from_numpy(outputs["output"])
+
+ # Update latents using scheduler (x_t -> x_t-1)
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # Handle dtype mismatch (workaround for MPS backend bug)
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ # Execute callback if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # Update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # Step 8: Decode latents to images (unless output_type is "latent")
+ if output_type == "latent":
+ image = latents
+ else:
+ # Unpack and denormalize latents
+ latents = self._unpack_latents(latents, self.height, self.width, self.vae_scale_factor)
+ latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor
+
+ # Initialize VAE decoder inference session
+ if self.vae_decode.qpc_session is None:
+ self.vae_decode.qpc_session = QAICInferenceSession(
+ str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids
+ )
+
+ # Allocate output buffer for VAE decoder
+ output_buffer = {"sample": np.random.rand(batch_size, 3, self.height, self.width).astype(np.int32)}
+ self.vae_decode.qpc_session.set_buffers(output_buffer)
+
+ # Run VAE decoder inference and measure time
+ inputs = {"latent_sample": latents.numpy()}
+ start_decode_time = time.time()
+ image = self.vae_decode.qpc_session.run(inputs)
+ end_decode_time = time.time()
+ vae_decode_perf = end_decode_time - start_decode_time
+
+ # Post-process image
+ image_tensor = torch.from_numpy(image["sample"])
+ image = self.image_processor.postprocess(image_tensor, output_type=output_type)
+
+ # Build performance metrics
+ perf_metrics = [
+ ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]),
+ ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]),
+ ModulePerf(module_name="transformer", perf=transformer_perf),
+ ModulePerf(module_name="vae_decoder", perf=vae_decode_perf),
+ ]
+
+ return QEffPipelineOutput(
+ pipeline_module=perf_metrics,
+ images=image,
+ )
diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py
new file mode 100644
index 000000000..224124b90
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_module.py
@@ -0,0 +1,538 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import copy
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+from QEfficient.base.modeling_qeff import QEFFBaseModel
+from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.diffusers.models.pytorch_transforms import (
+ AttentionTransform,
+ CustomOpsTransform,
+ NormalizationTransform,
+ OnnxFunctionTransform,
+)
+from QEfficient.transformers.models.pytorch_transforms import (
+ T5ModelTransform,
+)
+from QEfficient.utils import constants
+
+
+class QEffTextEncoder(QEFFBaseModel):
+ """
+ Wrapper for text encoder models with ONNX export and QAIC compilation capabilities.
+
+ This class handles text encoder models (CLIP, T5) with specific transformations and
+ optimizations for efficient inference on Qualcomm AI hardware. It applies custom
+ PyTorch and ONNX transformations to prepare models for deployment.
+
+ Attributes:
+ model (nn.Module): The wrapped text encoder model (deep copy of original)
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform, T5ModelTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the text encoder wrapper.
+
+ Args:
+ model (nn.Module): The text encoder model to wrap (CLIP or T5)
+ """
+ super().__init__(model)
+ self.model = copy.deepcopy(model)
+
+ def get_onnx_config(self) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the text encoder.
+
+ Creates example inputs, dynamic axes specifications, and output names
+ tailored to the specific text encoder type (CLIP vs T5).
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
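+
+        Example (illustrative sketch; ``enc`` is assumed to be a QEffTextEncoder wrapping
+        a CLIP or T5 text encoder, and the export directory is a placeholder):
+
+            inputs, axes, names = enc.get_onnx_config()
+            onnx_path = enc.export(inputs, names, axes, export_dir="./onnx_cache")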
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # Create example input with max sequence length
+ example_inputs = {
+ "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64),
+ }
+
+ # Define which dimensions can vary at runtime
+ dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+ # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output
+ if self.model.__class__.__name__ == "T5EncoderModel":
+ output_names = ["last_hidden_state"]
+ else:
+ output_names = ["last_hidden_state", "pooler_output"]
+ example_inputs["output_hidden_states"] = False
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the text encoder model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffUNet(QEFFBaseModel):
+ """
+ Wrapper for UNet models with ONNX export and QAIC compilation capabilities.
+
+ This class handles UNet models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. UNet is commonly used in
+ diffusion models for image generation tasks.
+
+ Attributes:
+ model (nn.Module): The wrapped UNet model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the UNet wrapper.
+
+ Args:
+ model (nn.Module): The pipeline model containing the UNet
+ """
+ super().__init__(model.unet)
+ self.model = model.unet
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the UNet model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffVAE(QEFFBaseModel):
+ """
+ Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation.
+
+ This class handles VAE models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion
+ pipelines for encoding images to latent space and decoding latents back to images.
+
+ Attributes:
+ model (nn.Module): The wrapped VAE model (deep copy of original)
+ type (str): VAE operation type ("encoder" or "decoder")
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(self, model: nn.Module, type: str) -> None:
+ """
+ Initialize the VAE wrapper.
+
+ Args:
+ model (nn.Module): The pipeline model containing the VAE
+ type (str): VAE operation type ("encoder" or "decoder")
+ """
+ super().__init__(model.vae)
+ self.model = copy.deepcopy(model.vae)
+ self.type = type
+
+ def get_onnx_config(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the VAE decoder.
+
+ Args:
+ latent_height (int): Height of latent representation (default: 32)
+ latent_width (int): Width of latent representation (default: 32)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
+ """
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+ # VAE decoder takes latent representation as input
+ example_inputs = {
+ "latent_sample": torch.randn(bs, 16, latent_height, latent_width),
+ "return_dict": False,
+ }
+
+ output_names = ["sample"]
+
+        # Batch, channel, and spatial dimensions are all marked dynamic for runtime flexibility
+ dynamic_axes = {
+ "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the VAE model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffSafetyChecker(QEFFBaseModel):
+ """
+ Wrapper for safety checker models with ONNX export and QAIC compilation capabilities.
+
+ This class handles safety checker models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. Safety checkers are used in diffusion
+ pipelines to filter out potentially harmful or inappropriate generated content.
+
+ Attributes:
+ model (nn.Module): The wrapped safety checker model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Initialize the safety checker wrapper.
+
+ Args:
+ model (nn.Module): The pipeline model containing the safety checker
+ """
+ super().__init__(model.safety_checker)
+ self.model = model.safety_checker
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the safety checker model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options
+ """
+ self._compile(specializations=specializations, **compiler_options)
+
+
+class QEffFluxTransformerModel(QEFFBaseModel):
+ """
+ Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities.
+
+ This class handles Flux Transformer2D models with specific transformations and optimizations
+ for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion
+ architecture instead of traditional UNet, with dual transformer blocks and adaptive layer
+ normalization (AdaLN) for conditioning.
+
+ Attributes:
+ model (nn.Module): The wrapped Flux transformer model
+ _pytorch_transforms (List): PyTorch transformations applied before ONNX export
+ _onnx_transforms (List): ONNX transformations applied after export
+ """
+
+ _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(self, model: nn.Module, use_onnx_function: bool) -> None:
+ """
+ Initialize the Flux transformer wrapper.
+
+ Args:
+ model (nn.Module): The Flux transformer model to wrap
+ use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
+ for better modularity and potential optimization
+ """
+
+        # Optionally apply ONNX function transform for modular export
+        if use_onnx_function:
+            model, _ = OnnxFunctionTransform.apply(model)
+
+ super().__init__(model)
+
+ if use_onnx_function:
+ self._pytorch_transforms.append(OnnxFunctionTransform)
+
+ # Ensure model is on CPU to avoid meta device issues
+ self.model = model.to("cpu")
+
+ def get_onnx_config(
+ self, batch_size: int = 1, seq_length: int = 256, cl: int = 4096
+ ) -> Tuple[Dict, Dict, List[str]]:
+ """
+ Generate ONNX export configuration for the Flux transformer.
+
+ Creates example inputs for all Flux-specific inputs including hidden states,
+ text embeddings, timestep conditioning, and AdaLN embeddings.
+
+ Args:
+ batch_size (int): Batch size for example inputs (default: 1)
+ seq_length (int): Text sequence length (default: 256)
+ cl (int): Compressed latent dimension (default: 4096)
+
+ Returns:
+ Tuple containing:
+ - example_inputs (Dict): Sample inputs for ONNX export
+ - dynamic_axes (Dict): Specification of dynamic dimensions
+ - output_names (List[str]): Names of model outputs
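+
+        Example (illustrative sketch of the export/compile flow; ``qeff_transformer`` and
+        the compiler options are placeholders, not prescribed values):
+
+            inputs, axes, names = qeff_transformer.get_onnx_config(batch_size=1, seq_length=256, cl=4096)
+            qeff_transformer.export(inputs, names, axes, export_dir="./onnx_cache")
+            specs = qeff_transformer.get_specializations(batch_size=1, seq_len=256, cl=4096)
+            qeff_transformer.compile(specializations=specs, num_cores=16)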
+ """
+ example_inputs = {
+ # Latent representation of the image
+ "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32),
+ # Text embeddings from T5 encoder
+ "encoder_hidden_states": torch.randn(
+ batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32
+ ),
+ # Pooled text embeddings from CLIP encoder
+ "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32),
+ # Diffusion timestep (normalized to [0, 1])
+ "timestep": torch.tensor([1.0], dtype=torch.float32),
+ # Position IDs for image patches
+ "img_ids": torch.randn(cl, 3, dtype=torch.float32),
+ # Position IDs for text tokens
+ "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32),
+ # AdaLN embeddings for dual transformer blocks
+ # Shape: [num_layers, 12 chunks (6 for norm1 + 6 for norm1_context), hidden_dim]
+ "adaln_emb": torch.randn(
+ self.model.config.num_layers,
+ 12, # 6 chunks for norm1 + 6 chunks for norm1_context
+ 3072, # AdaLN hidden dimension
+ dtype=torch.float32,
+ ),
+ # AdaLN embeddings for single transformer blocks
+ # Shape: [num_single_layers, 3 chunks, hidden_dim]
+ "adaln_single_emb": torch.randn(
+ self.model.config.num_single_layers,
+ 3, # 3 chunks for single block norm
+ 3072, # AdaLN hidden dimension
+ dtype=torch.float32,
+ ),
+ # Output AdaLN embedding
+ # Shape: [batch_size, 2 * hidden_dim] for final projection
+ "adaln_out": torch.randn(batch_size, 6144, dtype=torch.float32), # 2 * 3072
+ }
+
+ output_names = ["output"]
+
+ # Define dynamic dimensions for runtime flexibility
+ dynamic_axes = {
+ "hidden_states": {0: "batch_size", 1: "cl"},
+ "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
+ "pooled_projections": {0: "batch_size"},
+ "timestep": {0: "steps"},
+ "img_ids": {0: "cl"},
+ }
+
+ return example_inputs, dynamic_axes, output_names
+
+ def export(
+ self,
+ inputs: Dict,
+ output_names: List[str],
+ dynamic_axes: Dict,
+ export_dir: str = None,
+ export_kwargs: Dict = None,
+ ) -> str:
+ """
+ Export the Flux transformer model to ONNX format.
+
+ Args:
+ inputs (Dict): Example inputs for ONNX export
+ output_names (List[str]): Names of model outputs
+ dynamic_axes (Dict): Specification of dynamic dimensions
+ export_dir (str, optional): Directory to save ONNX model
+ export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
+
+ Returns:
+ str: Path to the exported ONNX model
+ """
+ return self._export(
+ example_inputs=inputs,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ export_dir=export_dir,
+ export_kwargs=export_kwargs,
+ )
+
+ def get_specializations(self, batch_size: int, seq_len: int, cl: int) -> List[Dict]:
+ """
+ Generate specialization configuration for compilation.
+
+ Specializations define fixed values for certain dimensions to enable
+ compiler optimizations specific to the target use case.
+
+ Args:
+ batch_size (int): Batch size for inference
+ seq_len (int): Text sequence length
+ cl (int): Compressed latent dimension
+
+ Returns:
+ List[Dict]: Specialization configurations for the compiler
+ """
+ specializations = [
+ {
+ "batch_size": batch_size,
+ "stats-batchsize": batch_size,
+ "num_layers": self.model.config.num_layers,
+ "num_single_layers": self.model.config.num_single_layers,
+ "seq_len": seq_len,
+ "cl": cl,
+ "steps": 1,
+ }
+ ]
+
+ return specializations
+
+ def compile(self, specializations: List[Dict], **compiler_options) -> None:
+ """
+ Compile the ONNX model for Qualcomm AI hardware.
+
+ Args:
+ specializations (List[Dict]): Model specialization configurations
+ **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+ """
+ self._compile(specializations=specializations, **compiler_options)
diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 000000000..5c8c2ba2d
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,195 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+
+from QEfficient.utils._utils import load_json
+from QEfficient.utils.logging_utils import logger
+
+
+def config_manager(cls, config_source: Optional[str] = None):
+ """
+    Load the JSON-based compilation configuration for a diffusion pipeline.
+
+    Reads the configuration from the given JSON file (or the pipeline's default
+    config when no path is provided) and stores it on the pipeline as
+    ``custom_config`` for use during compilation and execution.
+
+ Args:
+ config_source: Path to JSON configuration file. If None, uses default config.
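+
+    Example of the expected JSON layout (illustrative; module names and values are
+    placeholders inferred from how the config is consumed, not a documented schema):
+
+        {
+            "modules": {
+                "transformer": {
+                    "specializations": {"batch_size": 1},
+                    "compilation": {"num_cores": 16},
+                    "execute": {"device_ids": [0]}
+                }
+            }
+        }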
+ """
+ if config_source is None:
+ config_source = cls.get_default_config_path()
+
+ if not isinstance(config_source, str):
+ raise ValueError("config_source must be a path to JSON configuration file")
+
+ # Direct use of load_json utility - no wrapper needed
+ if not os.path.exists(config_source):
+ raise FileNotFoundError(f"Configuration file not found: {config_source}")
+
+ cls.custom_config = load_json(config_source)
+
+
+def set_module_device_ids(cls):
+ """
+ Set device IDs for each module based on the custom configuration.
+
+ Iterates through all modules in the pipeline and assigns device IDs
+ from the configuration file to each module's device_ids attribute.
+ """
+ config_modules = cls.custom_config["modules"]
+ for module_name, module_obj in cls.modules.items():
+ module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
+
+
+def compile_modules_parallel(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules in parallel using ThreadPoolExecutor.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
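+
+    Example (illustrative; assumes ``pipe`` exposes a ``modules`` dict and ``cfg`` follows
+    the JSON layout described in ``config_manager``):
+
+        compile_modules_parallel(
+            modules=pipe.modules,
+            config=cfg,
+            specialization_updates={"transformer": {"cl": 4096}},
+        )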
+ """
+
+ def _prepare_and_compile(module_name: str, module_obj: Any) -> None:
+ """Prepare specializations and compile a single module."""
+ specializations = config["modules"][module_name]["specializations"].copy()
+ compile_kwargs = config["modules"][module_name]["compilation"]
+
+ if specialization_updates and module_name in specialization_updates:
+ specializations.update(specialization_updates[module_name])
+
+ module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+ # Execute compilations in parallel
+ with ThreadPoolExecutor(max_workers=len(modules)) as executor:
+ futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()}
+
+ with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar:
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ logger.error(f"Compilation failed for {futures[future]}: {e}")
+ raise
+ pbar.update(1)
+
+
+def compile_modules_sequential(
+ modules: Dict[str, Any],
+ config: Dict[str, Any],
+ specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+ """
+ Compile multiple pipeline modules sequentially.
+
+ This function provides a generic way to compile diffusion pipeline modules
+ sequentially, which is the default behavior for backward compatibility.
+
+ Args:
+ modules: Dictionary of module_name -> module_object pairs to compile
+ config: Configuration dictionary containing module-specific compilation settings
+ specialization_updates: Optional dictionary of module_name -> specialization_updates
+ to apply dynamic values (e.g., image dimensions)
+
+ """
+ for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"):
+ module_config = config["modules"]
+ specializations = module_config[module_name]["specializations"].copy()
+ compile_kwargs = module_config[module_name]["compilation"]
+
+ # Apply dynamic specialization updates if provided
+ if specialization_updates and module_name in specialization_updates:
+ specializations.update(specialization_updates[module_name])
+
+ # Compile the module to QPC format
+ module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+
+@dataclass(frozen=True)
+class ModulePerf:
+ """
+ Data class to store performance metrics for a pipeline module.
+
+ Attributes:
+ module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder')
+ perf: Performance metric in seconds. Can be a single float for modules that run once,
+ or a list of floats for modules that run multiple times (e.g., transformer steps)
+ """
+
+ module_name: str
+    perf: Union[float, List[float]]
+
+
+@dataclass(frozen=True)
+class QEffPipelineOutput:
+ """
+ Data class to store the output of a QEfficient diffusion pipeline.
+
+ Attributes:
+ pipeline_module: List of ModulePerf objects containing performance metrics for each module
+ images: Generated images as either a list of PIL Images or numpy array
+ """
+
+    pipeline_module: List[ModulePerf]
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+ def __repr__(self):
+ output_str = "=" * 60 + "\n"
+ output_str += "QEfficient Diffusers Pipeline Inference Report\n"
+ output_str += "=" * 60 + "\n\n"
+
+ # Module-wise inference times
+ output_str += "Module-wise Inference Times:\n"
+ output_str += "-" * 60 + "\n"
+
+ # Calculate E2E time while iterating
+ e2e_time = 0
+ for module_perf in self.pipeline_module:
+ module_name = module_perf.module_name
+ inference_time = module_perf.perf
+
+ # Add to E2E time
+ e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time
+
+ # Format module name for display
+ display_name = module_name.replace("_", " ").title()
+
+ # Handle transformer specially as it has a list of times
+ if isinstance(inference_time, list) and len(inference_time) > 0:
+ total_time = sum(inference_time)
+ avg_time = total_time / len(inference_time)
+ output_str += f" {display_name:25s} {total_time:.4f} s\n"
+ output_str += f" - Total steps: {len(inference_time)}\n"
+ output_str += f" - Average per step: {avg_time:.4f} s\n"
+ output_str += f" - Min step time: {min(inference_time):.4f} s\n"
+ output_str += f" - Max step time: {max(inference_time):.4f} s\n"
+ else:
+ # Single inference time value
+ output_str += f" {display_name:25s} {inference_time:.4f} s\n"
+
+ output_str += "-" * 60 + "\n\n"
+
+ # Print E2E time after all modules
+ output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n"
+ output_str += "=" * 60 + "\n"
+
+ return output_str
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 60f60c768..aa49ef03b 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -122,21 +122,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying HuggingFace model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
class MultimodalUtilityMixin:
"""
@@ -302,18 +287,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
- @property
- def get_model_config(self) -> dict:
- """
- Get the model configuration as a dictionary.
-
- Returns
- -------
- dict
- The configuration dictionary of the underlying HuggingFace model.
- """
- return self.model.config.__dict__
-
def export(self, export_dir: Optional[str] = None) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -673,21 +646,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying vision encoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -815,21 +773,6 @@ def compile(
**compiler_options,
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying language decoder model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@property
def get_model_config(self) -> dict:
"""
@@ -886,21 +829,6 @@ def __init__(
self.continuous_batching = continuous_batching
self.input_shapes, self.output_names = None, None
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
"""
@@ -1898,33 +1826,6 @@ def cloud_ai_100_generate(
),
)
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying multimodal model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
- @property
- def get_model_config(self) -> dict:
- """
- Get the configuration dictionary of the underlying HuggingFace model.
-
- Returns
- -------
- dict
- The configuration dictionary.
- """
- return self.model.config.__dict__
-
class QEFFAutoModelForImageTextToText:
"""
@@ -2182,21 +2083,6 @@ def __init__(
if self.is_tlm:
self.model.qaic_config["return_pdfs"] = True
- @property
- def model_name(self) -> str:
- """
- Get the name of the underlying Causal Language Model.
-
- Returns
- -------
- str
- The model's class name, with "QEff" or "QEFF" prefix removed if present.
- """
- mname = self.model.__class__.__name__
- if mname.startswith("QEff") or mname.startswith("QEFF"):
- mname = mname[4:]
- return mname
-
def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()
@@ -2283,18 +2169,6 @@ def from_pretrained(
**kwargs,
)
- @property
- def get_model_config(self) -> dict:
- """
- Get the model configuration as a dictionary.
-
- Returns
- -------
- dict
- The configuration dictionary of the underlying HuggingFace model.
- """
- return self.model.config.__dict__
-
def export(self, export_dir: Optional[str] = None) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -2931,18 +2805,6 @@ def __init__(self, model: nn.Module, **kwargs):
self.num_layers = model.config.num_hidden_layers
self.hash_params["qeff_auto_class"] = self.__class__.__name__
- @property
- def get_model_config(self) -> dict:
- """
- Get the configuration dictionary of the underlying HuggingFace model.
-
- Returns
- -------
- dict
- The configuration dictionary.
- """
- return self.model.config.__dict__
-
def export(self, export_dir: Optional[str] = None) -> str:
"""
Export the model to ONNX format using ``torch.onnx.export``.
@@ -3303,10 +3165,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
- @property
- def get_model_config(self) -> dict:
- return self.model.config.__dict__
-
def export(self, export_dir: Optional[str] = None) -> str:
"""
Exports the model to ``ONNX`` format using ``torch.onnx.export``.
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 773ce178c..0f2a5a5fd 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -195,6 +195,13 @@
Starcoder2ForCausalLM,
Starcoder2Model,
)
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerCrossAttention,
+ T5LayerFF,
+ T5LayerNorm,
+ T5LayerSelfAttention,
+)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
@@ -414,6 +421,13 @@
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
+from QEfficient.transformers.models.t5.modeling_t5 import (
+ QEffT5Attention,
+ QEffT5LayerCrossAttention,
+ QEffT5LayerFF,
+ QEffT5LayerNorm,
+ QEffT5LayerSelfAttention,
+)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
@@ -804,6 +818,22 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
_match_class_replace_method = {}
+class T5ModelTransform(ModuleMappingTransform):
+ # supported architectures
+ _module_mapping = {
+ T5LayerFF: QEffT5LayerFF,
+ T5LayerSelfAttention: QEffT5LayerSelfAttention,
+ T5LayerCrossAttention: QEffT5LayerCrossAttention,
+ T5Attention: QEffT5Attention,
+ T5LayerNorm: QEffT5LayerNorm,
+ }
+
+ @classmethod
+ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+ model, transformed = super().apply(model)
+ return model, transformed
+
+
class PoolingTransform:
"""
Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output.
diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/transformers/models/t5/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py
new file mode 100644
index 000000000..9ba5869d7
--- /dev/null
+++ b/QEfficient/transformers/models/t5/modeling_t5.py
@@ -0,0 +1,217 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from transformers.models.t5.modeling_t5 import (
+ T5Attention,
+ T5LayerCrossAttention,
+ T5LayerFF,
+ T5LayerNorm,
+ T5LayerSelfAttention,
+)
+
+
+class QEffT5LayerNorm(T5LayerNorm):
+ def forward(self, hidden_states):
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
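+        # Equivalent to variance = hidden_states.pow(2).mean(-1, keepdim=True): scaling by
+        # 1/sqrt(hidden_size) before squaring keeps intermediate values smaller, which helps
+        # avoid overflow during accumulation.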
+ div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
+ variance = div_first.pow(2).sum(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class QEffT5LayerFF(T5LayerFF):
+ def forward(self, hidden_states):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ hidden_states = hidden_states * 1.0 + self.dropout(forwarded_states)
+ return hidden_states
+
+
+class QEffT5Attention(T5Attention):
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+ is_cross_attention = key_value_states is not None
+
+ query_states = self.q(hidden_states)
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
+ value_states = curr_past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self.k(current_states)
+ value_states = self.v(current_states)
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
+
+ if position_bias is None:
+ key_length = key_states.shape[-2]
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position
+ )
+            # QEff change: upstream slices the bias to the last seq_length query positions
+            # (position_bias[:, :, -seq_length:, :]); when a KV cache is present only the
+            # latest query position is needed, so keep just the final row.
+            if past_key_value is not None:
+                position_bias = position_bias[:, :, -1:, :]
+
+ if mask is not None:
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
+ position_bias = position_bias + causal_mask
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ scores += position_bias_masked
+
+ # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, past_key_value, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class QEffT5LayerSelfAttention(T5LayerSelfAttention):
+ def __qeff_init__(self):
+ self.scaling_factor = 1.0
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = hidden_states * 1.0 + self.dropout(attention_output[0]) # Modified by patch
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class QEffT5LayerCrossAttention(T5LayerCrossAttention):
+ def forward(
+ self,
+ hidden_states,
+ key_value_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ layer_output = hidden_states * 1.0 + self.dropout(attention_output[0]) # Modified by patch
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
+ return outputs
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index d58f54952..29ab567fb 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -532,7 +532,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict:
"""
model_params = copy.deepcopy(kwargs)
model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST}
- model_params["config"] = qeff_model.model.config.to_diff_dict()
+ model_params["config"] = (
+ qeff_model.model.config.to_diff_dict()
+ if hasattr(qeff_model.model.config, "to_diff_dict")
+ else qeff_model.model.config
+ )
model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None)
model_params["applied_transform_names"] = qeff_model._transform_names()
return model_params
@@ -564,7 +568,8 @@ def wrapper(self, *args, **kwargs):
model_params=self.hash_params,
output_names=all_args.get("output_names"),
dynamic_axes=all_args.get("dynamic_axes"),
- export_kwargs=all_args.get("export_kwargs", None),
+ # TODO: Re-enable export_kwargs hashing before merging this PR
+ # export_kwargs=all_args.get("export_kwargs", None),
onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
)
export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 1504bdae5..30e9afd17 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -129,6 +129,35 @@ def get_models_dir():
QWEN2_5_VL_WIDTH = 536
+# wo_sfs: weight output scaling factors (used to normalize T5 encoder output weights before export)
+WO_SFS = [
+ 61,
+ 203,
+ 398,
+ 615,
+ 845,
+ 1190,
+ 1402,
+ 2242,
+ 1875,
+ 2393,
+ 3845,
+ 3213,
+ 3922,
+ 4429,
+ 5020,
+ 5623,
+ 6439,
+ 6206,
+ 5165,
+ 4593,
+ 2802,
+ 2618,
+ 1891,
+ 1419,
+]
+
+
class Constants:
# Export Constants.
SEQ_LEN = 32
diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py
index b6b38b8b4..5fa2d0a78 100644
--- a/QEfficient/utils/hash_utils.py
+++ b/QEfficient/utils/hash_utils.py
@@ -15,6 +15,9 @@
def json_serializable(obj):
if isinstance(obj, set):
return sorted(obj)
+ # Handle objects with to_dict() method (e.g., transformers config objects)
+ if hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict")):
+ return obj.to_dict()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png
new file mode 100644
index 000000000..9e58da61d
Binary files /dev/null and b/docs/image/girl_laughing.png differ
diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py
new file mode 100644
index 000000000..438d9532f
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_schnell.py
@@ -0,0 +1,51 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1-schnell Image Generation Example
+
+This example demonstrates how to use the QEFFFluxPipeline to generate images
+using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a
+fast, distilled version of the FLUX.1 text-to-image model optimized for
+speed with minimal quality loss.
+
+Key Features:
+- Fast inference with only 4 steps
+- High-quality image generation from text prompts
+- Optimized for Qualcomm Cloud AI 100 via ONNX export and ahead-of-time compilation
+- Deterministic output using fixed random seed
+
+Output:
+- Generates an image based on the text prompt
+- Saves the image as 'cat_with_sign.png' in the current directory
+"""
+
+import torch
+
+from QEfficient import QEFFFluxPipeline
+
+# Initialize the FLUX.1-schnell pipeline from pretrained weights
+# use_onnx_function toggles modular ONNX export of transformer blocks (experimental); it is disabled here
+pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", use_onnx_function=False)
+
+# Generate an image from a text prompt
+output = pipeline(
+ prompt="A cat holding a sign that says hello world",
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+)
+
+# Extract the generated image from the output
+image = output.images[0]
+
+# Save the generated image to disk
+image.save("cat_with_sign.png")
+
+# Print the output object (contains perf info)
+print(output)
diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py
new file mode 100644
index 000000000..f9f52396e
--- /dev/null
+++ b/examples/diffusers/flux/flux_1_shnell_custom.py
@@ -0,0 +1,119 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+FLUX.1 Schnell Custom Configuration Example
+
+This example demonstrates how to customize the FLUX.1 model with various options:
+1. Custom image dimensions (height/width)
+2. Custom transformer model and text encoder
+3. Custom scheduler configuration
+4. Reduced model layers for faster inference
+5. Custom compilation settings
+6. Custom runtime configuration via JSON config file
+
+Use this example to learn how to fine-tune FLUX.1 for your specific needs.
+"""
+
+import torch
+
+from QEfficient import QEFFFluxPipeline
+
+# ============================================================================
+# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS
+# ============================================================================
+# Initialize the FLUX pipeline with custom settings.
+#
+# Key parameters:
+# - Base model: "black-forest-labs/FLUX.1-schnell" (optimized for fast inference)
+# - height/width: Output image dimensions (default is 1024x1024; this example uses 256x256)
+#
+# Note: Smaller dimensions = faster generation but lower resolution
+
+# Option 1: Basic initialization with custom image dimensions
+# NOTE: use_onnx_function=True enables modular ONNX export, which breaks the model into
+# smaller, more manageable ONNX functions and can improve export and compilation efficiency.
+# The feature is still experimental, so it is left disabled (False) in this example.
+pipeline = QEFFFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell", height=256, width=256, use_onnx_function=False
+)
+
+# Option 2: Advanced initialization with custom modules
+# Uncomment and modify to use your own custom components:
+#
+# pipeline = QEFFFluxPipeline.from_pretrained(
+# "black-forest-labs/FLUX.1-schnell",
+# height=512,
+# width=512,
+# text_encoder=custom_text_encoder, # Your custom CLIP text encoder
+# transformer=custom_transformer, # Your custom transformer model
+# tokenizer=custom_tokenizer, # Your custom tokenizer
+# )
+
+# ============================================================================
+# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION
+# ============================================================================
+# Uncomment to use a custom scheduler (e.g., different sampling methods):
+#
+# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config)
+
+# ============================================================================
+# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE
+# ============================================================================
+# Reduce the number of transformer blocks to speed up image generation.
+#
+# Trade-off: Faster inference but potentially lower image quality
+# Use case: Quick testing, prototyping, or when speed is critical
+#
+# Uncomment the following lines to use only the first transformer block:
+#
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
+# pipeline.transformer.model.config.num_layers = 1
+# pipeline.transformer.model.config.num_single_layers = 1
+
+# ============================================================================
+# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
+# ============================================================================
+# Pre-compile the model for optimized performance on target hardware.
+#
+# When to use:
+# - When you want to compile the model separately before generation
+# - When you need to skip image generation and only prepare the model
+#
+# Note: If compile_config is not specified, the default configuration from
+# QEfficient/diffusers/pipelines/flux/flux_config.json will be used
+#
+# Uncomment to compile with a custom configuration:
+# pipeline.compile(compile_config="examples/diffusers/flux/flux_config.json")
+
+
+# ============================================================================
+# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION
+# ============================================================================
+# Generate an image using the configured pipeline.
+# - custom_config_path: Path to JSON file with runtime settings (device IDs, etc.)
+#
+# Note: custom_config_path lets you set compilation and runtime options (such as
+# device_ids) per module, so the separate pipeline.compile() step above can be skipped.
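+#
+# For illustration only (the authoritative schema is the shipped
+# QEfficient/diffusers/pipelines/flux/flux_config.json), a per-module override could
+# pin the transformer to four devices:
+#   "transformer": { "execute": { "device_ids": [0, 1, 2, 3] } }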
+
+output = pipeline(
+ prompt="A girl laughing",
+ custom_config_path="examples/diffusers/flux/flux_config.json",
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ max_sequence_length=256,
+ generator=torch.manual_seed(42),
+    parallel_compile=True,  # compile the pipeline modules in parallel to reduce total compile time
+)
+
+image = output.images[0]
+# Save the generated image to disk
+image.save("girl_laughing.png")
+print(output)
diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json
new file mode 100644
index 000000000..c0d2b4bbc
--- /dev/null
+++ b/examples/diffusers/flux/flux_config.json
@@ -0,0 +1,94 @@
+{
+ "description": "Default configuration for Flux pipeline",
+
+ "modules":
+ {
+ "text_encoder":
+ {
+ "specializations":{
+ "batch_size": 1,
+ "seq_len": 77
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+        },
+ "text_encoder_2":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "transformer":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "seq_len": 256,
+ "steps": 1
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 4,
+ "mxfp6_matmul": true,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16,
+ "mos": 1,
+ "mdts-mos": 1
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ },
+ "vae_decoder":
+ {
+ "specializations":
+ {
+ "batch_size": 1,
+ "channels": 16
+ },
+ "compilation":
+ {
+ "onnx_path": null,
+ "compile_dir": null,
+ "mdp_ts_num_devices": 1,
+ "mxfp6_matmul": false,
+ "convert_to_fp16": true,
+ "aic_num_cores": 16
+ },
+ "execute":
+ {
+ "device_ids": null
+ }
+ }
+ }
+}
diff --git a/pyproject.toml b/pyproject.toml
index ea3c3405d..e32e2e88d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"transformers==4.55.0",
"huggingface-hub==0.34.0",
"hf_transfer==0.1.9",
- "peft==0.13.2",
+ "peft==0.17.0",
"datasets==2.20.0",
"fsspec==2023.6.0",
"multidict==6.0.4",
@@ -50,7 +50,7 @@ dependencies = [
test = ["pytest","pytest-mock"]
docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"]
quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"]
-
+diffusers = ["diffusers==0.35.1"]
[build-system]
requires = ["setuptools>=62.0.0"]
build-backend = "setuptools.build_meta"
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index d9d391d47..232d224af 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -22,6 +22,7 @@ pipeline {
. preflight_qeff/bin/activate &&
pip install --upgrade pip setuptools &&
pip install .[test] &&
+ pip install .[diffusers] &&
pip install junitparser pytest-xdist &&
pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing
pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs
@@ -34,7 +35,7 @@ pipeline {
parallel {
stage('Run Non-CLI Non-QAIC Tests') {
steps {
- timeout(time: 25, unit: 'MINUTES') {
+ timeout(time: 100, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&