From c8fc751466d4f171e46f266481541c29ac25d956 Mon Sep 17 00:00:00 2001
From: samuelt0
Date: Sat, 8 Nov 2025 15:36:28 -0500
Subject: [PATCH] cross attn module for ltx

---
 src/diffusers/models/transformers/transformer_ltx.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 685c73c07c75..cdc35e9b467c 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -142,6 +142,7 @@ def __init__(
         self.dropout = dropout
         self.out_dim = query_dim
         self.heads = heads
+        self.is_cross_attention: Optional[bool] = None

         norm_eps = 1e-5
         norm_elementwise_affine = True
@@ -166,6 +167,7 @@ def forward(
         image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        self.is_cross_attention = encoder_hidden_states is not None
         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
         unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
         if len(unused_kwargs) > 0:
@@ -324,6 +326,7 @@ def __init__(
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
         )
+        self.attn1.is_cross_attention = False

         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
         self.attn2 = LTXAttention(
@@ -336,6 +339,7 @@ def __init__(
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
         )
+        self.attn2.is_cross_attention = True

         self.ff = FeedForward(dim, activation_fn=activation_fn)
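
For review context: the four insertions tag each LTXAttention instance with an is_cross_attention flag. Each transformer block pins attn1 to False and attn2 to True at construction time, and LTXAttention.forward refreshes the flag from whether encoder_hidden_states was passed. Below is a minimal standalone sketch of that pattern, using a hypothetical TinyAttention module in place of the real LTXAttention class, not the actual diffusers implementation:

import torch
from torch import nn
from typing import Optional


class TinyAttention(nn.Module):
    # Hypothetical stand-in for LTXAttention, showing the flag the patch adds.

    def __init__(self, dim: int) -> None:
        super().__init__()
        # None until a forward pass (or the owning block) records how the layer is used.
        self.is_cross_attention: Optional[bool] = None
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Same rule as the patched LTXAttention.forward: cross-attention iff
        # a separate encoder sequence is provided.
        self.is_cross_attention = encoder_hidden_states is not None
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        q = self.to_q(hidden_states)
        k = self.to_k(context)
        v = self.to_v(context)
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


if __name__ == "__main__":
    attn = TinyAttention(dim=8)
    x = torch.randn(1, 4, 8)
    ctx = torch.randn(1, 6, 8)

    attn(x)                             # self-attention
    print(attn.is_cross_attention)      # False
    attn(x, encoder_hidden_states=ctx)  # cross-attention
    print(attn.is_cross_attention)      # True

Because the patch also sets the flag statically in the block constructor (False for attn1, True for attn2), attention processors or debugging hooks can read it even before the first forward call.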