Commit 96e360d

[WIP] Adding support for custom Height,width

tv-karthikeya and Amit Raj authored and committed
Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>

1 parent bed22a1, commit 96e360d

File tree: 6 files changed (+166, -76 lines)

QEfficient/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def check_qaic_sdk():
 from QEfficient.compile.compile_helper import compile
 
 # Imports for the diffusers
-from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
+from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
 from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
 from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
 from QEfficient.peft import QEffAutoPeftModelForCausalLM

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,7 @@
 from diffusers.models.attention import JointTransformerBlock
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
 from diffusers.models.normalization import RMSNorm, AdaLayerNormZero, AdaLayerNormZeroSingle, AdaLayerNormContinuous
-from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel
+from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxAttnProcessor, FluxAttention
 
 from QEfficient.base.pytorch_transforms import ModuleMappingTransform
 from QEfficient.customop.rms_norm import CustomRMSNormAIC
@@ -19,7 +19,7 @@
     QEffAttention,
     QEffJointAttnProcessor2_0,
 )
-from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxSingleTransformerBlock, QEffFluxTransformerBlock, QEffFluxTransformer2DModel
+from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxSingleTransformerBlock, QEffFluxTransformerBlock, QEffFluxTransformer2DModel, QEffFluxAttnProcessor, QEffFluxAttention
 from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous
 
 class CustomOpsTransform(ModuleMappingTransform):
@@ -42,6 +42,8 @@ class AttentionTransform(ModuleMappingTransform):
         FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
         FluxTransformerBlock: QEffFluxTransformerBlock,
         FluxTransformer2DModel: QEffFluxTransformer2DModel,
+        FluxAttention: QEffFluxAttention,
+        FluxAttnProcessor: QEffFluxAttnProcessor,
     }
 
     @classmethod
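
The two new mapping entries register QEff replacements for FluxAttention and FluxAttnProcessor alongside the existing Flux block mappings. As a rough sketch of the class-swap idea behind a module-mapping transform (the helper name apply_module_mapping below is hypothetical and not QEfficient's actual ModuleMappingTransform API), such a transform can walk the model and re-class mapped modules in place:

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    # Sketch only: swap each mapped module's class in place so existing weights
    # are reused while the replacement class supplies the new forward/processor logic.
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
            # Optional hook, mirroring __qeff_init__ on QEffFluxAttention, to let the
            # replacement install its own attention processor after the swap.
            if hasattr(module, "__qeff_init__"):
                module.__qeff_init__()
    return model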

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 96 additions & 32 deletions
@@ -12,13 +12,102 @@
 import torch.nn as nn
 import numpy as np
 
-from diffusers.models.transformers.transformer_flux import FluxAttention, FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxPosEmbed, FluxAttnProcessor
+from diffusers.models.transformers.transformer_flux import FluxAttention, FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxPosEmbed, FluxAttnProcessor, _get_qkv_projections
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
-from diffusers.models.attention import FeedForward
+from diffusers.models.attention_dispatch import dispatch_attention_fn
 
 from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous
 
+def qeff_apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]) -> torch.Tensor:
+    """
+    Apply rotary embeddings to the query or key tensor `x` using the precomputed
+    frequency tensors in `freqs_cis`. The last dimension of `x` is treated as
+    interleaved (real, imaginary) pairs, each pair is rotated by the per-position
+    cos/sin values, and the result is returned with the same dtype as the input.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor of shape [B, S, H, D] to apply rotary embeddings to.
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed (cos, sin) frequency tensors, each of shape [S, D].
+
+    Returns:
+        `torch.Tensor`: The input tensor with rotary embeddings applied.
+    """
+    cos, sin = freqs_cis  # [S, D]
+    cos = cos[None, :, None, :]
+    sin = sin[None, :, None, :]
+    cos, sin = cos.to(x.device), sin.to(x.device)
+    B, S, H, D = x.shape
+    x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    return out
+
+class QEffFluxAttnProcessor(FluxAttnProcessor):
+    _attention_backend = None
+    _parallel_config = None
+
+    def __call__(
+        self,
+        attn: "QEffFluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        if image_rotary_emb is not None:
+            query = qeff_apply_rotary_emb(query, image_rotary_emb)
+            key = qeff_apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = dispatch_attention_fn(
+            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+class QEffFluxAttention(FluxAttention):
+    def __qeff_init__(self):
+        processor = QEffFluxAttnProcessor()
+        self.processor = processor
+
+
 class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
     def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio)
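
A minimal sketch of the rotation that qeff_apply_rotary_emb performs on a [B, S, H, D] query or key tensor. The random cos/sin tables are stand-ins for the rotary tables produced by FluxPosEmbed, and the sizes are illustrative only:

import torch

B, S, H, D = 1, 4, 2, 8  # toy sizes; D must be even
x = torch.randn(B, S, H, D)
cos = torch.randn(S, D)  # stand-in for the precomputed rotary cosine table
sin = torch.randn(S, D)  # stand-in for the precomputed rotary sine table

# Treat the last dimension as interleaved (real, imag) pairs and rotate each pair
# by the per-position angle, as in the helper above.
x_real, x_imag = x.reshape(B, S, H, D // 2, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = x * cos[None, :, None, :] + x_rotated * sin[None, :, None, :]
print(out.shape)  # torch.Size([1, 4, 2, 8])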
@@ -27,13 +116,13 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
-        self.attn = FluxAttention(
+        self.attn = QEffFluxAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=FluxAttnProcessor(),
+            processor=QEffFluxAttnProcessor(),
             eps=1e-6,
             pre_only=True,
         )
@@ -77,23 +166,18 @@ def __init__(
 
         self.norm1 = QEffAdaLayerNormZero(dim)
         self.norm1_context = QEffAdaLayerNormZero(dim)
-        self.attn = FluxAttention(
+        self.attn = QEffFluxAttention(
             query_dim=dim,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor(),
+            processor=QEffFluxAttnProcessor(),
             eps=eps,
         )
 
-        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
-
-        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
     def forward(
         self,
@@ -174,9 +258,6 @@ def __init__(
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):
 
-        resolved_out_channels = out_channels or in_channels
-        inner_dim = num_attention_heads * attention_head_dim
-
         super().__init__(
             patch_size=patch_size,
             in_channels=in_channels,
@@ -191,21 +272,6 @@ def __init__(
             axes_dims_rope=axes_dims_rope,
         )
 
-        self.out_channels = resolved_out_channels
-        self.inner_dim = inner_dim
-
-        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
-
-        text_time_guidance_cls = (
-            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
-        )
-        self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-        )
-
-        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
-
         self.transformer_blocks = nn.ModuleList(
             [
                 QEffFluxTransformerBlock(
@@ -229,9 +295,7 @@ def __init__(
         )
 
         self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
-        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
-        self.gradient_checkpointing = False
 
     def forward(
         self,
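
The processor added in this file concatenates the text (encoder) tokens in front of the image tokens for joint attention and then splits the result back into the two streams. A minimal sketch of that split step, using hypothetical token counts (512 text tokens, 1024 image tokens) and torch.Tensor.split_with_sizes as in QEffFluxAttnProcessor.__call__:

import torch

text_len, image_len, dim = 512, 1024, 3072  # hypothetical Flux-like sizes
joint = torch.randn(2, text_len + image_len, dim)

# The first text_len tokens go back to encoder_hidden_states, the rest to hidden_states.
encoder_hidden_states, hidden_states = joint.split_with_sizes(
    [text_len, joint.shape[1] - text_len], dim=1
)
print(encoder_hidden_states.shape, hidden_states.shape)
# torch.Size([2, 512, 3072]) torch.Size([2, 1024, 3072])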
