 import torch.nn as nn
 import numpy as np

-from diffusers.models.transformers.transformer_flux import FluxAttention, FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxPosEmbed, FluxAttnProcessor
+from diffusers.models.transformers.transformer_flux import FluxAttention, FluxSingleTransformerBlock, FluxTransformerBlock, FluxTransformer2DModel, FluxPosEmbed, FluxAttnProcessor, _get_qkv_projections
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
-from diffusers.models.attention import FeedForward
+from diffusers.models.attention_dispatch import dispatch_attention_fn

 from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, QEffAdaLayerNormContinuous

+def qeff_apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to the query or key tensor `x` using the precomputed frequency
+    tensors in `freqs_cis`. The last dimension of `x` is treated as interleaved
+    (real, imaginary) pairs, rotated with the cos/sin tables, and the result is returned
+    as a real tensor in the original dtype.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
+        freqs_cis (`Tuple[torch.Tensor]`):
+            Precomputed (cos, sin) frequency tensors, each of shape [S, D].
+
+    Returns:
+        torch.Tensor: The query or key tensor with rotary embeddings applied.
+    """
+    cos, sin = freqs_cis  # each [S, D]
+    cos = cos[None, :, None, :]
+    sin = sin[None, :, None, :]
+    cos, sin = cos.to(x.device), sin.to(x.device)
+    B, S, H, D = x.shape
+    x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1)
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    return out
+
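+# QEff override of diffusers' FluxAttnProcessor: it reuses _get_qkv_projections for the
+# query/key/value projections, applies the real-valued rotary helper above, and routes
+# the attention computation through dispatch_attention_fn.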
+class QEffFluxAttnProcessor(FluxAttnProcessor):
+    _attention_backend = None
+    _parallel_config = None
+
+    def __call__(
+        self,
+        attn: "QEffFluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
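+        # Dual-stream blocks also carry a text (encoder) stream: project and normalize it,
+        # then prepend it to the image tokens so one attention call covers the joint sequence.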
+        if attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
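+        # Apply rotary position embeddings with the real-valued helper defined above.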
+        if image_rotary_emb is not None:
+            query = qeff_apply_rotary_emb(query, image_rotary_emb)
+            key = qeff_apply_rotary_emb(key, image_rotary_emb)
+
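+        # Scaled dot-product attention via diffusers' dispatch helper; the backend is taken
+        # from the processor's _attention_backend attribute.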
+        hidden_states = dispatch_attention_fn(
+            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
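+# QEff override of FluxAttention: __qeff_init__ swaps the attention processor for
+# QEffFluxAttnProcessor (presumably invoked by QEfficient's module-mapping transform
+# once the module class has been replaced).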
+class QEffFluxAttention(FluxAttention):
+    def __qeff_init__(self):
+        processor = QEffFluxAttnProcessor()
+        self.processor = processor
+
+
 class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock):
     def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio)
@@ -27,13 +116,13 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
-        self.attn = FluxAttention(
+        self.attn = QEffFluxAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=FluxAttnProcessor(),
+            processor=QEffFluxAttnProcessor(),
             eps=1e-6,
             pre_only=True,
         )
@@ -77,23 +166,18 @@ def __init__(

         self.norm1 = QEffAdaLayerNormZero(dim)
         self.norm1_context = QEffAdaLayerNormZero(dim)
-        self.attn = FluxAttention(
+        self.attn = QEffFluxAttention(
             query_dim=dim,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor(),
+            processor=QEffFluxAttnProcessor(),
             eps=eps,
         )

-        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
-
-        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

     def forward(
         self,
@@ -174,9 +258,6 @@ def __init__(
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):

-        resolved_out_channels = out_channels or in_channels
-        inner_dim = num_attention_heads * attention_head_dim
-
         super().__init__(
             patch_size=patch_size,
             in_channels=in_channels,
@@ -191,21 +272,6 @@ def __init__(
             axes_dims_rope=axes_dims_rope,
         )

-        self.out_channels = resolved_out_channels
-        self.inner_dim = inner_dim
-
-        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
-
-        text_time_guidance_cls = (
-            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
-        )
-        self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-        )
-
-        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
-
         self.transformer_blocks = nn.ModuleList(
             [
                 QEffFluxTransformerBlock(
@@ -229,9 +295,7 @@ def __init__(
         )

         self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
-        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

-        self.gradient_checkpointing = False

     def forward(
         self,