@@ -431,6 +431,18 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        self.kv_lens_expanded_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences * (1 + self.max_draft_tokens), ),
+            cache_name="kv_lens_expanded_cuda",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.kv_lens_expanded_host = torch.zeros_like(
+            self.kv_lens_expanded_cuda,
+            device='cpu',
+            pin_memory=True,
+        )
 
     def prepare(self):
         super().prepare()
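
The two buffers added in `__post_init__` follow the usual CUDA-graph staging pattern: the device tensor is allocated once at its maximum size so its address stays stable across graph replays, and the pinned host mirror makes the per-step host-to-device copy asynchronous. A minimal sketch of that pattern, with hypothetical sizes and with `get_empty`/the buffer cache elided:

import torch

max_num_sequences = 8   # hypothetical capacity
max_draft_tokens = 3    # hypothetical draft length

# Device buffer sized for the worst case, so a CUDA graph can capture
# a stable address.
kv_lens_expanded_cuda = torch.zeros(
    max_num_sequences * (1 + max_draft_tokens),
    dtype=torch.int32, device='cuda')
# Pinned host mirror: staging area for non-blocking H2D copies.
kv_lens_expanded_host = torch.zeros_like(
    kv_lens_expanded_cuda, device='cpu', pin_memory=True)

# Each step: fill the host mirror, then copy only the live prefix.
num_entries = 6         # hypothetical number of live entries this step
kv_lens_expanded_host[:num_entries] = torch.arange(1, num_entries + 1,
                                                   dtype=torch.int32)
kv_lens_expanded_cuda[:num_entries].copy_(
    kv_lens_expanded_host[:num_entries], non_blocking=True)
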
@@ -534,6 +546,10 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
+        # Expand kv_lens for draft tokens (generation requests only)
+        gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+        self.kv_lens_expanded_host = torch.cat([gen_kv_lens] * (1 + self.max_draft_tokens), dim=0)
+
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
 
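
The expansion added to `prepare` simply tiles the generation-phase KV lengths once per draft position, so downstream kernels see one length entry per (request, draft-step) pair. A worked toy example with made-up values:

import torch

max_draft_tokens = 2
# KV lengths of two generation requests (hypothetical values).
gen_kv_lens = torch.tensor([10, 17], dtype=torch.int32)
expanded = torch.cat([gen_kv_lens] * (1 + max_draft_tokens), dim=0)
print(expanded)  # tensor([10, 17, 10, 17, 10, 17], dtype=torch.int32)
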
@@ -1049,9 +1065,15 @@ def sparse_attn_indexer(
         # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
         q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
                          ...]
-        q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-        batch_size = q_decode.shape[0]
-        next_n = q_decode.shape[1]
+        batch_size = num_generations
+        next_n = num_gen_tokens // num_generations
+        if next_n <= 2:
+            q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
+            context_lens = metadata.kv_lens_cuda_runtime[num_contexts:num_contexts +
+                                                         num_generations]
+        else:
+            q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+
         assert num_gen_tokens == batch_size * next_n
         weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
                                  num_gen_tokens, ...]
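
The new branch keeps the old grouped layout for shallow speculation (next_n <= 2), where decode tokens stay batched per request as [batch_size, next_n, ...] and per-request context lengths are sliced from kv_lens_cuda_runtime; for deeper draft lengths it flattens every token into its own single-step row, presumably to pair with the expanded KV lengths introduced earlier in this diff. A shape-only sketch with hypothetical dimensions:

import torch

num_generations, num_gen_tokens, heads, dim = 2, 8, 4, 16
q = torch.randn(num_gen_tokens, heads, dim)

next_n = num_gen_tokens // num_generations  # 4 tokens per request here
if next_n <= 2:
    # Shallow speculation: tokens stay grouped per request.
    q_decode = q.view(num_generations, -1, heads, dim)  # [2, 4, 4, 16]
else:
    # Deeper speculation: each token becomes its own one-step row.
    q_decode = q.view(-1, 1, heads, dim)                # [8, 1, 4, 16]
print(q_decode.shape)  # torch.Size([8, 1, 4, 16]) on this path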