
Commit f7e7527

quic-amitraj (Amit Raj) authored and committed
Added OnnxFunctionTransform and code cleanup while modifying compile and export APIs
Signed-off-by: Amit Raj <amitraj@qti.qualcomm.com>
Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>
1 parent 3f33a63 commit f7e7527

File tree

6 files changed (+209, -225 lines)


QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 9 additions & 1 deletion
@@ -58,7 +58,6 @@ class AttentionTransform(ModuleMappingTransform):
         JointTransformerBlock: QEffJointTransformerBlock,
         FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
         FluxTransformerBlock: QEffFluxTransformerBlock,
-        FluxTransformer2DModel: QEffFluxTransformer2DModel,
         FluxAttention: QEffFluxAttention,
         FluxAttnProcessor: QEffFluxAttnProcessor,
     }
@@ -80,3 +79,12 @@ class NormalizationTransform(ModuleMappingTransform):
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
         return model, transformed
+
+
+class OnnxFunctionTransform(ModuleMappingTransform):
+    _module_mapping = {FluxTransformer2DModel: QEffFluxTransformer2DModel}
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        model, transformed = super().apply(model)
+        return model, transformed
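
Note on the change above: the FluxTransformer2DModel -> QEffFluxTransformer2DModel mapping is removed from AttentionTransform and carried only by the new OnnxFunctionTransform, so the model-level swap can be applied as its own pass during export. Below is a minimal, self-contained sketch of how a ModuleMappingTransform-style class swap can work; the generic base shown here is illustrative only and is not the library's actual implementation.

# Illustrative sketch only (not QEfficient's real ModuleMappingTransform):
# walk the module tree, swap mapped classes in place, and report whether
# anything changed, mirroring the (model, transformed) return seen above.
from typing import Dict, Tuple, Type

import torch.nn as nn


class ModuleMappingTransformSketch:
    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]] = {}

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
        for module in model.modules():
            replacement = cls._module_mapping.get(type(module))
            if replacement is not None:
                # Re-point the instance's class so the replacement's methods take over.
                module.__class__ = replacement
                transformed = True
        return model, transformed

A subclass like OnnxFunctionTransform then only has to declare its _module_mapping; the shared apply logic does the traversal.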

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 22 additions & 196 deletions
@@ -5,13 +5,10 @@
 #
 # ----------------------------------------------------------------------------
 from typing import Any, Dict, Optional, Tuple, Union
-from venv import logger
 
-import numpy as np
 import torch
 import torch.nn as nn
 from diffusers.models.attention_dispatch import dispatch_attention_fn
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_flux import (
     FluxAttention,
     FluxAttnProcessor,
@@ -22,7 +19,6 @@
 )
 
 from QEfficient.diffusers.models.normalization import (
-    QEffAdaLayerNormContinuous,
     QEffAdaLayerNormZero,
     QEffAdaLayerNormZeroSingle,
 )
@@ -253,198 +249,28 @@ def forward(
 
 
 class QEffFluxTransformer2DModel(FluxTransformer2DModel):
-    def __init__(
-        self,
-        patch_size: int = 1,
-        in_channels: int = 64,
-        out_channels: Optional[int] = None,
-        num_layers: int = 19,
-        num_single_layers: int = 38,
-        attention_head_dim: int = 128,
-        num_attention_heads: int = 24,
-        joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
-        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
-    ):
-        super().__init__(
-            patch_size=patch_size,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            num_layers=num_layers,
-            num_single_layers=num_single_layers,
-            attention_head_dim=attention_head_dim,
-            num_attention_heads=num_attention_heads,
-            joint_attention_dim=joint_attention_dim,
-            pooled_projection_dim=pooled_projection_dim,
-            guidance_embeds=guidance_embeds,
-            axes_dims_rope=axes_dims_rope,
-        )
-
-        self.transformer_blocks = nn.ModuleList(
-            [
-                QEffFluxTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                QEffFluxSingleTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                )
-                for _ in range(num_single_layers)
-            ]
-        )
-
-        self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+    def __qeff_init__(self):
+        self.transformer_blocks = nn.ModuleList()
+        self._block_classes = set()
+
+        for _ in range(self.config.num_layers):
+            BlockClass = QEffFluxTransformerBlock
+            block = BlockClass(
+                dim=self.inner_dim,
+                num_attention_heads=self.config.num_attention_heads,
+                attention_head_dim=self.config.attention_head_dim,
+            )
+            self.transformer_blocks.append(block)
+            self._block_classes.add(BlockClass)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor = None,
-        pooled_projections: torch.Tensor = None,
-        timestep: torch.LongTensor = None,
-        img_ids: torch.Tensor = None,
-        txt_ids: torch.Tensor = None,
-        adaln_emb: torch.Tensor = None,
-        adaln_single_emb: torch.Tensor = None,
-        adaln_out: torch.Tensor = None,
-        guidance: torch.Tensor = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_block_samples=None,
-        controlnet_single_block_samples=None,
-        return_dict: bool = True,
-        controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        """
-        The [`FluxTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
-                Input `hidden_states`.
-            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
-                Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            joint_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
-        """
-
-        hidden_states = self.x_embedder(hidden_states)
-
-        timestep = timestep.to(hidden_states.dtype) * 1000
-        if guidance is not None:
-            guidance = guidance.to(hidden_states.dtype) * 1000
-
-        temb = (
-            self.time_text_embed(timestep, pooled_projections)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, pooled_projections)
-        )
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+        self.single_transformer_blocks = nn.ModuleList()
 
-        if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            txt_ids = txt_ids[0]
-        if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        for _ in range(self.config.num_single_layers):
+            SingleBlockClass = QEffFluxSingleTransformerBlock
+            single_block = SingleBlockClass(
+                dim=self.inner_dim,
+                num_attention_heads=self.config.num_attention_heads,
+                attention_head_dim=self.config.attention_head_dim,
             )
-            img_ids = img_ids[0]
-
-        ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
-
-        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
-            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
-            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
-            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
-
-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    joint_attention_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=adaln_emb[index_block],
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
-
-            # controlnet residual
-            if controlnet_block_samples is not None:
-                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
-                interval_control = int(np.ceil(interval_control))
-                # For Xlabs ControlNet.
-                if controlnet_blocks_repeat:
-                    hidden_states = (
-                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
-                    )
-                else:
-                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-
-        for index_block, block in enumerate(self.single_transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    joint_attention_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=adaln_single_emb[index_block],
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
-
-            # controlnet residual
-            if controlnet_single_block_samples is not None:
-                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
-                interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
-
-        hidden_states = self.norm_out(hidden_states, adaln_out)
-        output = self.proj_out(hidden_states)
-
-        if not return_dict:
-            return (output,)
-
-        return Transformer2DModelOutput(sample=output)
+            self.single_transformer_blocks.append(single_block)
+            self._block_classes.add(SingleBlockClass)
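
Note on the change above: instead of overriding __init__ (and the whole forward) and re-declaring every constructor argument, the QEff model now rebuilds its blocks inside a __qeff_init__ hook that reads counts and dimensions from self.config, and records the block classes it used in self._block_classes. A minimal, self-contained sketch of that config-driven rebuild pattern follows; TinyConfig, TinyBlock, and TinyModel are hypothetical names for illustration, not part of the library.

# Hypothetical example of the pattern, not library code: rebuild submodules
# from a stored config inside a __qeff_init__ hook and track the block classes.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class TinyConfig:
    num_layers: int = 2
    hidden_dim: int = 8


class TinyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class TinyModel(nn.Module):
    def __init__(self, config: TinyConfig):
        super().__init__()
        self.config = config
        self.__qeff_init__()

    def __qeff_init__(self):
        # Mirror of the diff: counts and dims come from self.config, and every
        # block class in use is remembered in _block_classes.
        self.blocks = nn.ModuleList(TinyBlock(self.config.hidden_dim) for _ in range(self.config.num_layers))
        self._block_classes = {TinyBlock}

Whether weights need to be (re)loaded after such a rebuild depends on when the transform that triggers __qeff_init__ runs in the export flow.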

QEfficient/diffusers/pipelines/config_manager.py

Lines changed: 0 additions & 2 deletions
@@ -33,5 +33,3 @@ def config_manager(cls, config_source: Optional[str] = None):
             raise FileNotFoundError(f"Configuration file not found: {config_source}")
 
         cls._compile_config = load_json(config_source)
-
-
