@@ -121,6 +121,7 @@ def __init__(
             qk_norm: bool = False,
             scale_norm: bool = False,
             proj_bias: bool = True,
+            rotate_half: bool = False,
     ):
         """Initialize the Attention module.

@@ -136,6 +137,7 @@ def __init__(
             norm_layer: Normalization layer constructor to use for QK and scale normalization
             qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
+            rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
         if scale_norm or qk_norm:
@@ -148,6 +150,7 @@ def __init__(
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
+        self.rotate_half = rotate_half

         if qkv_fused:
             self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
@@ -196,8 +199,9 @@ def forward(

         if rope is not None:
             npt = self.num_prefix_tokens
-            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
-            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
+            half = getattr(self, 'rotate_half', False)
+            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
+            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)

         if self.fused_attn:
             x = F.scaled_dot_product_attention(
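
For context, a minimal sketch of the difference between the two ROPE layouts the new `rotate_half` flag selects between. This is illustrative only: the function names are hypothetical, and it assumes separate `sin`/`cos` tensors of half the head dimension rather than the concatenated `rope` embedding that `apply_rot_embed_cat` actually consumes.

```python
import torch

def rope_interleaved(x, sin, cos):
    # Default 'interleaved' layout: rotate adjacent pairs (x0, x1), (x2, x3), ...
    # sin/cos have shape (..., head_dim // 2), broadcastable against x.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.flatten(-2)

def rope_half(x, sin, cos):
    # 'half' (rotate_half / NeoX-style) layout: pair x_i with x_{i + head_dim // 2}.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

Both apply the same rotation math; only the pairing of dimensions differs, which matters when loading weights trained with one convention or the other.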