Commit 4d45694

quic-amitraj (Amit Raj) authored and committed
Compile fix
Signed-off-by: Amit Raj <amitraj@qti.qualcomm.com>
Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>
1 parent f7e7527 commit 4d45694

File tree

4 files changed: +210 -210 lines changed


QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class AttentionTransform(ModuleMappingTransform):
         JointTransformerBlock: QEffJointTransformerBlock,
         FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
         FluxTransformerBlock: QEffFluxTransformerBlock,
+        FluxTransformer2DModel: QEffFluxTransformer2DModel,
         FluxAttention: QEffFluxAttention,
         FluxAttnProcessor: QEffFluxAttnProcessor,
     }
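For orientation, the new `FluxTransformer2DModel: QEffFluxTransformer2DModel` entry means the transform now swaps the top-level transformer class as well, not only its blocks and attention modules. Below is a minimal sketch of the general class-mapping idea only; it is not QEfficient's actual `ModuleMappingTransform` code, and `swap_module_classes` is a hypothetical helper name.

import torch.nn as nn

def swap_module_classes(model: nn.Module, mapping: dict) -> nn.Module:
    # Walk every submodule (including the root module itself) and, where its
    # class appears in the mapping, rebind it to the QEff replacement class so
    # the patched forward is used on subsequent calls.
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model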

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 196 additions & 22 deletions
@@ -5,10 +5,13 @@
 #
 # ----------------------------------------------------------------------------
 from typing import Any, Dict, Optional, Tuple, Union
+from venv import logger

+import numpy as np
 import torch
 import torch.nn as nn
 from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_flux import (
     FluxAttention,
     FluxAttnProcessor,
@@ -19,6 +22,7 @@
 )

 from QEfficient.diffusers.models.normalization import (
+    QEffAdaLayerNormContinuous,
     QEffAdaLayerNormZero,
     QEffAdaLayerNormZeroSingle,
 )
@@ -249,28 +253,198 @@ def forward(


 class QEffFluxTransformer2DModel(FluxTransformer2DModel):
-    def __qeff_init__(self):
-        self.transformer_blocks = nn.ModuleList()
-        self._block_classes = set()
-
-        for _ in range(self.config.num_layers):
-            BlockClass = QEffFluxTransformerBlock
-            block = BlockClass(
-                dim=self.inner_dim,
-                num_attention_heads=self.config.num_attention_heads,
-                attention_head_dim=self.config.attention_head_dim,
-            )
-            self.transformer_blocks.append(block)
-            self._block_classes.add(BlockClass)
+    def __init__(
+        self,
+        patch_size: int = 1,
+        in_channels: int = 64,
+        out_channels: Optional[int] = None,
+        num_layers: int = 19,
+        num_single_layers: int = 38,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        guidance_embeds: bool = False,
+        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+    ):
+        super().__init__(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
+            attention_head_dim=attention_head_dim,
+            num_attention_heads=num_attention_heads,
+            joint_attention_dim=joint_attention_dim,
+            pooled_projection_dim=pooled_projection_dim,
+            guidance_embeds=guidance_embeds,
+            axes_dims_rope=axes_dims_rope,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QEffFluxTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                QEffFluxSingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_single_layers)
+            ]
+        )

-        self.single_transformer_blocks = nn.ModuleList()
+        self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)

-        for _ in range(self.config.num_single_layers):
-            SingleBlockClass = QEffFluxSingleTransformerBlock
-            single_block = SingleBlockClass(
-                dim=self.inner_dim,
-                num_attention_heads=self.config.num_attention_heads,
-                attention_head_dim=self.config.attention_head_dim,
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        adaln_emb: torch.Tensor = None,
+        adaln_single_emb: torch.Tensor = None,
+        adaln_out: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
+        return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        """
+        The [`FluxTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+
+        hidden_states = self.x_embedder(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
             )
-            self.single_transformer_blocks.append(single_block)
-            self._block_classes.add(SingleBlockClass)
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=adaln_emb[index_block],
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=adaln_single_emb[index_block],
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
+
+        hidden_states = self.norm_out(hidden_states, adaln_out)
+        output = self.proj_out(hidden_states)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
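Taken together, the QEff subclass now builds its QEff transformer blocks and the `QEffAdaLayerNormContinuous` output norm eagerly in `__init__` (replacing the removed `__qeff_init__` hook), and its `forward` consumes precomputed AdaLN tensors (`adaln_emb`, `adaln_single_emb`, `adaln_out`), feeding one slice per block, presumably so the AdaLN projections can be computed outside this module. The sketch below is a hypothetical usage example, assuming `diffusers` and QEfficient are installed; the tiny layer counts are illustrative, not the Flux defaults of 19 and 38.

from QEfficient.diffusers.models.transformers.transformer_flux import (
    QEffFluxTransformer2DModel,
)

# Small, illustrative configuration; real Flux checkpoints use the defaults.
model = QEffFluxTransformer2DModel(num_layers=1, num_single_layers=1)

# The blocks are QEff classes straight from the constructor; no post-hoc
# transform of the block list is needed.
print(type(model.transformer_blocks[0]).__name__)         # QEffFluxTransformerBlock
print(type(model.single_transformer_blocks[0]).__name__)  # QEffFluxSingleTransformerBlock
print(type(model.norm_out).__name__)                      # QEffAdaLayerNormContinuous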

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def compile(
         specializations[0]["latent_height"] = self.latent_height
         specializations[0]["latent_width"] = self.latent_width
         # Compile the module
-        module_obj._compile(specializations=specializations, **compile_kwargs)
+        module_obj.compile(specializations=specializations, **compile_kwargs)

     def _get_t5_prompt_embeds(
         self,
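This one-line change matches the commit title: the pipeline now calls the public `compile` entry point on each wrapped module instead of reaching into the internal `_compile` helper. The snippet below is an illustrative, hypothetical wrapper pattern (not QEfficient's actual class) showing why the public method is the safer call site.

class ExportedModule:
    # Hypothetical stand-in for a QEfficient-wrapped pipeline component.
    def compile(self, specializations=None, **compile_kwargs):
        # Public API: normalize arguments before delegating to the internal helper.
        specializations = specializations or [{}]
        return self._compile(specializations=specializations, **compile_kwargs)

    def _compile(self, specializations, **compile_kwargs):
        # Internal worker: assumes the public wrapper has prepared its inputs.
        return f"compiled with {len(specializations)} specialization(s)"

print(ExportedModule().compile())  # compiled with 1 specialization(s)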
