@@ -431,6 +431,18 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        self.kv_lens_expanded_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences * (1 + self.max_draft_tokens), ),
+            cache_name="kv_lens_expanded_cuda",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.kv_lens_expanded_host = torch.zeros_like(
+            self.kv_lens_expanded_cuda,
+            device='cpu',
+            pin_memory=True,
+        )
 
     def prepare(self):
         super().prepare()
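
The new buffer pair follows a common CUDA Graph staging pattern: the device tensor is sized for the worst case so captured graphs see a stable address, while the pinned host mirror makes the per-step refresh an asynchronous host-to-device copy. A minimal standalone sketch of that pattern (sizes and names here are illustrative, not the module's real values; requires a CUDA device):

```python
import torch

max_num_sequences = 8   # assumed capacity
max_draft_tokens = 3    # assumed speculative depth
flat_len = max_num_sequences * (1 + max_draft_tokens)

# Worst-case device buffer: a captured CUDA graph can keep replaying
# against this fixed allocation.
kv_lens_expanded_cuda = torch.zeros(flat_len, dtype=torch.int32, device='cuda')
# Pinned (page-locked) host mirror: enables a truly asynchronous H2D copy.
kv_lens_expanded_host = torch.zeros_like(kv_lens_expanded_cuda,
                                         device='cpu', pin_memory=True)

# Per step, only the live prefix is refreshed.
n = 2 * (1 + max_draft_tokens)  # e.g. two generation requests
kv_lens_expanded_cuda[:n].copy_(kv_lens_expanded_host[:n], non_blocking=True)
```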
@@ -534,6 +546,15 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
+        # Expand kv_lens for draft tokens (generation requests only): fill the
+        # preallocated pinned host buffer, then stage the device copy asynchronously.
+        gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+        num_expanded = gen_kv_lens.shape[0] * (1 + self.max_draft_tokens)
+        self.kv_lens_expanded_host[:num_expanded] = gen_kv_lens.repeat(
+            1 + self.max_draft_tokens)
+        self.kv_lens_expanded_cuda[:num_expanded].copy_(
+            self.kv_lens_expanded_host[:num_expanded], non_blocking=True)
+
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
@@ -1047,9 +1068,17 @@ def sparse_attn_indexer(
         # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
         q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
                          ...]
-        q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-        batch_size = q_decode.shape[0]
-        next_n = q_decode.shape[1]
+        batch_size = num_generations
+        next_n = num_gen_tokens // num_generations
+        if next_n <= 2:
+            q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
+            context_lens = metadata.kv_lens_cuda_runtime[
+                num_contexts:num_contexts + num_generations]
+        else:
+            # One query row per token; use the per-token expanded kv lens.
+            q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+            context_lens = metadata.kv_lens_expanded_cuda[:num_gen_tokens]
+
         assert num_gen_tokens == batch_size * next_n
         weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
                                  num_gen_tokens, ...]
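
A small shape sketch of the two decode layouts above (dimensions assumed): for next_n <= 2 queries stay grouped per sequence, so the kernel sees [batch_size, next_n, ...]; for deeper speculation each token becomes its own single-step row.

```python
import torch

num_generations, head_dim = 4, 64

for next_n in (2, 4):
    q = torch.randn(num_generations * next_n, head_dim)
    if next_n <= 2:
        q_decode = q.view(num_generations, -1, *q.shape[1:])  # per-sequence grouping
    else:
        q_decode = q.view(-1, 1, *q.shape[1:])                # one row per token
    print(next_n, tuple(q_decode.shape))
# 2 (4, 2, 64)
# 4 (16, 1, 64)
```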