@@ -1969,10 +1969,10 @@ def forward_sparse_mla_kvcache_bf16(
             q, latent_cache, attn_metadata, is_generation=is_generation)
 
         num_tokens = q.shape[0]
-        q_nope, q_rope = q.view(-1, self.num_heads, self.qk_head_dim).split(
+        q_nope, q_rope = q.view(-1, self.num_heads_tp, self.qk_head_dim).split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         q_nope_out = torch.empty(
-            [num_tokens, self.num_heads, (self.kv_lora_rank)],
+            [num_tokens, self.num_heads_tp, (self.kv_lora_rank)],
             dtype=q.dtype,
             device=q.device,
         )
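
A standalone sketch (not part of the patch) of the shape bookkeeping in the hunk above, assuming `num_heads_tp` is the per-tensor-parallel-rank head count and using illustrative MLA head dimensions:

```python
import torch

# Illustrative values only; the real ones come from the model config.
num_tokens, num_heads_tp = 4, 16
qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank = 128, 64, 512
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

q = torch.randn(num_tokens, num_heads_tp * qk_head_dim)

# Same view + split as in the hunk: last dim -> [nope | rope] parts per head.
q_nope, q_rope = q.view(-1, num_heads_tp, qk_head_dim).split(
    [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
q_nope_out = torch.empty(num_tokens, num_heads_tp, kv_lora_rank,
                         dtype=q.dtype, device=q.device)

assert q_nope.shape == (num_tokens, num_heads_tp, qk_nope_head_dim)
assert q_rope.shape == (num_tokens, num_heads_tp, qk_rope_head_dim)
```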
@@ -2011,23 +2011,23 @@ def forward_sparse_mla_kvcache_bf16(
         # FlashMLA sparse kernel (bf16) requires num_heads=128 on sm100 or multiple of 64 on sm90
         if sm_version >= 100:
             padding = 128
-            assert self.num_heads <= padding, (
+            assert self.num_heads_tp <= padding, (
                 f"SM100 FlashMLA sparse kernel requires exactly {padding} heads, "
-                f"got {self.num_heads}. Padding from values > {padding} is not supported."
+                f"got {self.num_heads_tp}. Padding from values > {padding} is not supported."
             )
         else:  # SM90
-            padding = ((self.num_heads + 63) // 64) * 64  # multiple of 64
+            padding = ((self.num_heads_tp + 63) // 64) * 64  # multiple of 64
 
-        if self.num_heads != padding:
+        if self.num_heads_tp != padding:
             logger.warning_once(
-                f"Padding num_heads from {self.num_heads} to {padding} "
+                f"Padding num_heads from {self.num_heads_tp} to {padding} "
                 f"due to FlashMLA sparse attention kernel requirement",
                 key="sparse_mla_padding_warning")
 
             # Create padded tensor with zeros for extra heads
             q_padded = q_concat.new_empty(
                 (num_tokens, padding, q_concat.shape[2]))
-            q_padded[:, :self.num_heads, :] = q_concat
+            q_padded[:, :self.num_heads_tp, :] = q_concat
             q_concat = q_padded
 
         # Convert indices and return all-layer KV pool
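
A standalone sketch (not part of the patch) of the head-padding rule in the hunk above: exactly 128 heads on SM100, round up to the next multiple of 64 on SM90, with the extra heads zero-filled. The helper `pad_heads` and the concrete sizes are hypothetical.

```python
import torch


def pad_heads(q_concat: torch.Tensor, num_heads_tp: int, sm_version: int) -> torch.Tensor:
    """Zero-pad the head dimension to what the FlashMLA sparse kernel accepts."""
    num_tokens = q_concat.shape[0]
    if sm_version >= 100:
        padding = 128                                # SM100: exactly 128 heads
        assert num_heads_tp <= padding
    else:
        padding = ((num_heads_tp + 63) // 64) * 64   # SM90: next multiple of 64

    if num_heads_tp != padding:
        q_padded = q_concat.new_zeros((num_tokens, padding, q_concat.shape[2]))
        q_padded[:, :num_heads_tp, :] = q_concat     # extra heads stay zero
        q_concat = q_padded
    return q_concat


q = torch.randn(4, 16, 576)  # e.g. 16 heads per rank, illustrative last dim
assert pad_heads(q, 16, sm_version=90).shape == (4, 64, 576)
assert pad_heads(q, 16, sm_version=100).shape == (4, 128, 576)
```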
@@ -2049,17 +2049,17 @@ def forward_sparse_mla_kvcache_bf16(
                 "flash_mla_sparse_fwd not available. Please ensure FlashMLA module is built."
             )
 
-        # [seq, num_heads, kv_lora_rank]
-        attn_out_latent = attn_out_latent[:, :self.
-                                          num_heads, :]  # account for padding
+        # [seq, num_heads, kv_lora_rank], account for padding
+        attn_out_latent = attn_out_latent[:, :self.num_heads_tp, :]
         # TODO: seems we need .contiguous() here when padding enabled before pass to bmm?
         attn_out_latent = attn_out_latent.view(
-            [-1, self.num_heads, self.kv_lora_rank])
+            [-1, self.num_heads_tp, self.kv_lora_rank])
 
         assert (attn_out_latent.shape[0] == q.shape[0]
-                and attn_out_latent.shape[1] == self.num_heads)
+                and attn_out_latent.shape[1] == self.num_heads_tp)
 
-        attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])
+        attn_output = output.view(
+            [num_tokens, self.num_heads_tp, self.v_head_dim])
 
         if self.v_b_proj.dtype == torch.bfloat16:
             # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
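
A standalone sketch (not part of the patch) of the per-head latent-to-value projection that the shape comment above describes; `w_v_b` is a hypothetical stand-in for the `v_b_proj` weight and the dimensions are illustrative.

```python
import torch

num_tokens, num_heads_tp, kv_lora_rank, v_head_dim = 4, 16, 512, 128

attn_out_latent = torch.randn(num_tokens, num_heads_tp, kv_lora_rank)
w_v_b = torch.randn(num_heads_tp, kv_lora_rank, v_head_dim)  # stand-in for v_b_proj

# [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
out = torch.bmm(attn_out_latent.transpose(0, 1).contiguous(), w_v_b)
attn_output = out.transpose(0, 1)  # back to [num_tokens, num_heads_tp, v_head_dim]
assert attn_output.shape == (num_tokens, num_heads_tp, v_head_dim)
```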