
Commit 3f33a63

quic-amitraj (Amit Raj) authored and committed
Flux support with Custom config
Signed-off-by: Amit Raj <amitraj@qti.qualcomm.com>
1 parent 96e360d commit 3f33a63

10 files changed: +367 −288 lines changed

QEfficient/diffusers/models/normalization.py

Lines changed: 8 additions & 6 deletions
@@ -4,10 +4,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # ----------------------------------------------------------------------------
-import numbers
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple
+
 import torch
-from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle,AdaLayerNormContinuous
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle


 class QEffAdaLayerNormZero(AdaLayerNormZero):
@@ -21,13 +21,14 @@ def forward(
         scale_msa: Optional[torch.Tensor] = None,
         # emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        if self.emb is not None:
-            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        # if self.emb is not None:
+        #     emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
         # emb = self.linear(self.silu(emb))
         # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x

+
 class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
     def forward(
         self,
@@ -39,11 +40,12 @@ def forward(
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x

+
 class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
         # emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
         emb = conditioning_embedding
         scale, shift = torch.chunk(emb, 2, dim=1)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
-        return x
+        return x
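Note on the change: the QEff norm layers drop the in-graph embedding computation and instead take the shift/scale modulation tensors as explicit inputs, so only the normalize-and-modulate arithmetic remains in the exported graph. A minimal standalone sketch of that call pattern (sizes below are illustrative assumptions, not values from this commit):

import torch
import torch.nn as nn

# Sketch of the QEffAdaLayerNormZeroSingle-style forward: the modulation
# vectors arrive precomputed, and the layer only normalizes and modulates.
norm = nn.LayerNorm(128, elementwise_affine=False, eps=1e-6)
x = torch.randn(1, 16, 128)        # (batch, sequence, hidden) - made-up sizes
shift_msa = torch.randn(1, 128)    # precomputed outside the layer
scale_msa = torch.randn(1, 128)

out = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
print(out.shape)                   # torch.Size([1, 16, 128])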

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 29 additions & 10 deletions
@@ -5,12 +5,18 @@
 #
 # -----------------------------------------------------------------------------
 from typing import Tuple
-import torch
-from torch import nn
+
 from diffusers.models.attention import JointTransformerBlock
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
-from diffusers.models.normalization import RMSNorm, AdaLayerNormZero, AdaLayerNormZeroSingle, AdaLayerNormContinuous
-from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxAttnProcessor,FluxAttention
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.models.transformers.transformer_flux import (
+    FluxAttention,
+    FluxAttnProcessor,
+    FluxSingleTransformerBlock,
+    FluxTransformer2DModel,
+    FluxTransformerBlock,
+)
+from torch import nn

 from QEfficient.base.pytorch_transforms import ModuleMappingTransform
 from QEfficient.customop.rms_norm import CustomRMSNormAIC
@@ -19,14 +25,25 @@
     QEffAttention,
     QEffJointAttnProcessor2_0,
 )
-from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxSingleTransformerBlock, QEffFluxTransformerBlock, QEffFluxTransformer2DModel, QEffFluxAttnProcessor, QEffFluxAttention
-from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous
+from QEfficient.diffusers.models.normalization import (
+    QEffAdaLayerNormContinuous,
+    QEffAdaLayerNormZero,
+    QEffAdaLayerNormZeroSingle,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+    QEffFluxAttention,
+    QEffFluxAttnProcessor,
+    QEffFluxSingleTransformerBlock,
+    QEffFluxTransformer2DModel,
+    QEffFluxTransformerBlock,
+)
+

 class CustomOpsTransform(ModuleMappingTransform):
     _module_mapping = {
         RMSNorm: CustomRMSNormAIC,
-        nn.RMSNorm: CustomRMSNormAIC # for torch.nn.RMSNorm
-    }
+        nn.RMSNorm: CustomRMSNormAIC,  # for torch.nn.RMSNorm
+    }

     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
@@ -42,21 +59,23 @@ class AttentionTransform(ModuleMappingTransform):
         FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
         FluxTransformerBlock: QEffFluxTransformerBlock,
         FluxTransformer2DModel: QEffFluxTransformer2DModel,
-        FluxAttention : QEffFluxAttention,
-        FluxAttnProcessor: QEffFluxAttnProcessor
+        FluxAttention: QEffFluxAttention,
+        FluxAttnProcessor: QEffFluxAttnProcessor,
     }

     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
         return model, transformed

+
 class NormalizationTransform(ModuleMappingTransform):
     _module_mapping = {
         AdaLayerNormZero: QEffAdaLayerNormZero,
         AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
         AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
     }
+
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
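Note on the change: AttentionTransform and NormalizationTransform are plain class-for-class substitutions driven by _module_mapping. A rough illustration of how such a mapping can be applied to a module tree (a sketch of the idea, not the actual ModuleMappingTransform implementation):

from typing import Dict, Tuple, Type

import torch.nn as nn


def apply_module_mapping(
    model: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> Tuple[nn.Module, bool]:
    # Swap the class of every mapped module in place, keeping its parameters,
    # so only the forward() implementation changes.
    transformed = False
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
            transformed = True
    return model, transformed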

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 26 additions & 18 deletions
@@ -4,23 +4,33 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # ----------------------------------------------------------------------------
-import os
-from typing import Any, Callable, Dict, List, Tuple, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 from venv import logger

+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
-
-from diffusers.models.transformers.transformer_flux import FluxAttention,FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxPosEmbed, FluxAttnProcessor, _get_qkv_projections
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+    FluxAttention,
+    FluxAttnProcessor,
+    FluxSingleTransformerBlock,
+    FluxTransformer2DModel,
+    FluxTransformerBlock,
+    _get_qkv_projections,
+)
+
+from QEfficient.diffusers.models.normalization import (
+    QEffAdaLayerNormContinuous,
+    QEffAdaLayerNormZero,
+    QEffAdaLayerNormZeroSingle,
+)

-from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous

 def qeff_apply_rotary_emb(
-    x: torch.Tensor,
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
+    x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
@@ -39,12 +49,13 @@ def qeff_apply_rotary_emb(
     cos = cos[None, :, None, :]
     sin = sin[None, :, None, :]
     cos, sin = cos.to(x.device), sin.to(x.device)
-    B, S, H , D = x.shape
-    x_real, x_imag = x.reshape(B, -1, H, D//2, 2).unbind(-1)
+    B, S, H, D = x.shape
+    x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
     x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
     return out

+
 class QEffFluxAttnProcessor(FluxAttnProcessor):
     _attention_backend = None
     _parallel_config = None
@@ -102,6 +113,7 @@ def __call__(
         else:
             return hidden_states

+
 class QEffFluxAttention(FluxAttention):
     def __qeff_init__(self):
         processor = QEffFluxAttnProcessor()
@@ -158,6 +170,7 @@ def forward(
         encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
         return encoder_hidden_states, hidden_states

+
 class QEffFluxTransformerBlock(FluxTransformerBlock):
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
@@ -178,7 +191,6 @@ def __init__(
             eps=eps,
         )

-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -187,15 +199,12 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
         temb1 = tuple(torch.split(temb[:6], 1))
         temb2 = tuple(torch.split(temb[6:], 1))
         norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1])
         gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:]

-        norm_encoder_hidden_states = self.norm1_context(
-            encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1]
-        )
+        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1])

         c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:]

@@ -242,6 +251,7 @@ def forward(

         return encoder_hidden_states, hidden_states

+
 class QEffFluxTransformer2DModel(FluxTransformer2DModel):
     def __init__(
         self,
@@ -257,7 +267,6 @@ def __init__(
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):
-
         super().__init__(
             patch_size=patch_size,
             in_channels=in_channels,
@@ -296,7 +305,6 @@ def __init__(

         self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)

-
     def forward(
         self,
         hidden_states: torch.Tensor,
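Note on the change: qeff_apply_rotary_emb keeps the interleaved real/imaginary layout, splitting the head dimension into (D/2, 2) pairs and mixing each pair with the cos/sin tables; the transformer blocks likewise receive their modulation tensor (temb) precomputed and simply slice it apart with torch.split. A small shape-check sketch with made-up sizes (not values from this commit):

import torch

# Rotary step: x is (batch, seq, heads, head_dim); cos/sin broadcast over batch and heads.
B, S, H, D = 1, 8, 4, 64
x = torch.randn(B, S, H, D)
cos = torch.randn(S, D)[None, :, None, :]
sin = torch.randn(S, D)[None, :, None, :]

x_real, x_imag = x.reshape(B, S, H, D // 2, 2).unbind(-1)      # interleaved (a, b) pairs
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)  # (-b, a) per pair
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
print(out.shape)  # torch.Size([1, 8, 4, 64])

# Modulation step: twelve precomputed vectors, six per stream, unpacked as in
# QEffFluxTransformerBlock.forward (the hidden size 64 is an assumption).
temb = torch.randn(12, 64)
temb1 = tuple(torch.split(temb[:6], 1))   # image stream: shift/scale/gate for attn and mlp
temb2 = tuple(torch.split(temb[6:], 1))   # text stream: shift/scale/gate for attn and mlp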
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+from typing import Optional
+
+from QEfficient.utils._utils import load_json
+
+
+def config_manager(cls, config_source: Optional[str] = None):
+    """
+    JSON-based compilation configuration manager for diffusion pipelines.
+
+    Supports loading configuration from JSON files only. Automatically detects
+    model type and handles model-specific requirements.
+    Initialize the configuration manager.
+
+    Args:
+        config_source: Path to JSON configuration file. If None, uses default config.
+    """
+    if config_source is None:
+        config_source = cls.get_default_config_path()
+
+    if not isinstance(config_source, str):
+        raise ValueError("config_source must be a path to JSON configuration file")
+
+    # Direct use of load_json utility - no wrapper needed
+    if not os.path.exists(config_source):
+        raise FileNotFoundError(f"Configuration file not found: {config_source}")
+
+    cls._compile_config = load_json(config_source)
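Note on the change: config_manager only validates the path and stores the parsed JSON on the class as _compile_config; the per-module sections are consumed later at export/compile time. A hedged usage sketch (the pipeline class, file names, and default-path helper below are placeholders, not part of this commit):

from QEfficient.utils._utils import load_json


class DummyFluxPipeline:
    # Stand-in for whichever pipeline class wires in config_manager.
    _compile_config = None

    @classmethod
    def get_default_config_path(cls) -> str:
        return "path/to/default_flux_config.json"  # placeholder


# Same effect as calling config_manager(DummyFluxPipeline, "my_flux_config.json"):
DummyFluxPipeline._compile_config = load_json("my_flux_config.json")
transformer_cfg = DummyFluxPipeline._compile_config["modules"]["transformer"]
print(transformer_cfg["compilation"]["mdp_ts_num_devices"])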

QEfficient/diffusers/pipelines/flux/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
 # SPDX-License-Identifier: BSD-3-Clause
 #
-# ----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+{
+    "description": "Example compilation configuration for Flux pipeline",
+    "model_type": "flux",
+
+    "modules":
+    {
+        "text_encoder":
+        {
+            "specializations": {
+                "batch_size": 1,
+                "seq_len": 77
+            },
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 1,
+                "mxfp6_matmul": false,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
+
+        },
+        "text_encoder_2":
+        {
+            "specializations":
+            {
+                "batch_size": 1,
+                "seq_len": 256
+            },
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 1,
+                "mxfp6_matmul": false,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
+        },
+        "transformer":
+        {
+            "specializations":
+            {
+                "batch_size": 1,
+                "seq_len": 256,
+                "steps": 1,
+                "num_layers": 1,
+                "num_single_layers": 1
+            },
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 4,
+                "mxfp6_matmul": true,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16,
+                "mos": 1,
+                "mdts-mos": 1
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
+        },
+        "vae_decoder":
+        {
+            "specializations":
+            {
+                "batch_size": 1,
+                "channels": 16
+            },
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 1,
+                "mxfp6_matmul": false,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
+        }
+    }
+}
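Each module entry in this example config carries three sections: specializations (export-time shape values such as batch_size, seq_len, and the reduced num_layers/num_single_layers used here), compilation (compiler options such as mdp_ts_num_devices, mxfp6_matmul, convert_to_fp16, aic_num_cores), and execute (runtime device_ids). One way to derive a custom config from the example, assuming the file names below (they are placeholders):

import json

with open("flux_example_config.json") as f:        # placeholder name for this example file
    cfg = json.load(f)

# Spread the transformer across more devices and export two double/single blocks.
cfg["modules"]["transformer"]["compilation"]["mdp_ts_num_devices"] = 8
cfg["modules"]["transformer"]["specializations"]["num_layers"] = 2
cfg["modules"]["transformer"]["specializations"]["num_single_layers"] = 2

with open("flux_custom_config.json", "w") as f:    # placeholder output name
    json.dump(cfg, f, indent=4)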

QEfficient/diffusers/pipelines/flux/config/default_flux_execute_config.json

Whitespace-only changes.
