@@ -443,6 +443,28 @@ def __post_init__(self):
443443 device = 'cpu' ,
444444 pin_memory = True ,
445445 )
446+ self .block_table_expanded = self .get_empty (
447+ self .cuda_graph_buffers ,
448+ [
449+ self .max_num_sequences * (1 + self .max_draft_tokens ),
450+ self .kv_cache_manager .max_blocks_per_seq
451+ ],
452+ cache_name = "block_table_expanded" ,
453+ dtype = torch .int32 ,
454+ capture_graph = capture_graph ,
455+ )
456+ self .host_block_table_expanded = torch .zeros_like (
457+ self .block_table_expanded ,
458+ device = 'cpu' ,
459+ pin_memory = True ,
460+ )
461+ self .scheduler_metadata_buffer_expanded = self .get_empty (
462+ self .cuda_graph_buffers ,
463+ (self .num_sms + 1 , 2 ),
464+ cache_name = "scheduler_metadata_buffer_expanded" ,
465+ dtype = torch .int32 ,
466+ capture_graph = capture_graph ,
467+ )
446468
447469 def prepare (self ):
448470 super ().prepare ()
@@ -546,9 +568,38 @@ def prepare(self):
546568 else :
547569 self .max_gen_seq_len = 0
548570
549- # Expand kv_lens_cuda for draft tokens (only generation)
550- gen_kv_lens = kv_lens [self .num_contexts :self .num_seqs ]
551- self .kv_lens_expanded_host = torch .cat ([gen_kv_lens ] * (1 + self .max_draft_tokens ), dim = 0 )
571+ # Because fp8_paged_mqa_logits only supports seq_len == 1 or 2, it cannot support
572+ # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand the
573+ # kv_lens and block_table so that fp8_paged_mqa_logits can be used.
574+ if self .max_draft_tokens > 1 :
575+ # Expand kv_lens_cuda (only generation)
576+ num_tokens = self .num_generations * (1 + self .max_draft_tokens )
577+ gen_kv_lens = kv_lens [self .num_contexts :self .num_seqs ]
578+ gen_kv_lens_expanded = torch .stack ([gen_kv_lens ] *
579+ (1 + self .max_draft_tokens ),
580+ dim = 0 )
581+ gen_kv_lens_expanded = gen_kv_lens_expanded .transpose (
582+ 0 , 1 ).contiguous ().flatten ()
583+ self .kv_lens_expanded_host [:num_tokens ].copy_ (gen_kv_lens_expanded )
584+ self .kv_lens_expanded_cuda [:num_tokens ].copy_ (
585+ self .kv_lens_expanded_host [:num_tokens ], non_blocking = True )
586+
587+ # Expand indexer_k_cache_block_offsets (only generation)
588+ if self .kv_cache_manager is not None :
589+ block_ids = self .kv_cache_manager .get_batch_cache_indices (
590+ self .request_ids )
591+ for i in range (self .num_contexts , len (block_ids )):
592+ for j in range (1 + self .max_draft_tokens ):
593+ self .host_block_table_expanded [
594+ (i - self .num_contexts ) *
595+ (1 + self .max_draft_tokens ) +
596+ j , :len (block_ids [i ])].copy_ (
597+ torch .tensor (block_ids [i ],
598+ dtype = torch .int32 ,
599+ device = 'cpu' ))
600+ self .block_table_expanded [:num_tokens ].copy_ (
601+ self .host_block_table_expanded [:num_tokens ],
602+ non_blocking = True )
552603
553604 # Prepare metadata for indexer
554605 Indexer .prepare (metadata = self )
@@ -814,6 +865,15 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
814865 gen_seq_lens , tokens_per_block , metadata .num_sms )
815866 metadata .scheduler_metadata_buffer .copy_ (scheduler_metadata_buffer ,
816867 non_blocking = True )
868+ if metadata .max_draft_tokens > 1 :
869+ # Expand scheduler metadata buffer (only generation)
870+ num_tokens = metadata .num_generations * (
871+ 1 + metadata .max_draft_tokens )
872+ kv_lens_expanded = metadata .kv_lens_expanded_cuda [:num_tokens ]
873+ scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata (
874+ kv_lens_expanded , tokens_per_block , metadata .num_sms )
875+ metadata .scheduler_metadata_buffer_expanded .copy_ (
876+ scheduler_metadata_buffer_expanded , non_blocking = True )
817877
818878 # Compute slot_mapping for all requests (both context and generation)
819879 # This maps each token to its flat cache position for vectorized KV cache updates
@@ -1065,12 +1125,21 @@ def sparse_attn_indexer(
10651125 ...]
10661126 batch_size = num_generations
10671127 next_n = num_gen_tokens // num_generations
1128+ # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor
1129+ # and expand the corresponding metadata.
10681130 if next_n <= 2 :
10691131 q_decode = q_decode .view (num_generations , - 1 , * q_fp8 .shape [1 :])
1070- context_lens = metadata .kv_lens_cuda_runtime [num_contexts :num_contexts +
1071- num_generations ]
1132+ context_lens = metadata .kv_lens_cuda_runtime [
1133+ num_contexts :num_contexts + num_generations ]
1134+ block_table = metadata .indexer_k_cache_block_offsets [
1135+ num_contexts :num_contexts + num_generations ]
1136+ scheduler_metadata_buffer = metadata .scheduler_metadata_buffer
10721137 else :
10731138 q_decode = q_decode .view (- 1 , 1 , * q_fp8 .shape [1 :])
1139+ num_tokens = num_generations * (1 + metadata .max_draft_tokens )
1140+ context_lens = metadata .kv_lens_expanded_cuda [:num_tokens ]
1141+ block_table = metadata .block_table_expanded [:num_tokens ]
1142+ scheduler_metadata_buffer = metadata .scheduler_metadata_buffer_expanded
10741143
10751144 assert num_gen_tokens == batch_size * next_n
10761145 weights_decode = weights [num_ctx_tokens :num_ctx_tokens +
@@ -1080,18 +1149,11 @@ def sparse_attn_indexer(
10801149 # [num_blocks, tokens_per_block, 1, head_dim + scale_size]
10811150 k_cache = metadata .kv_cache_manager .get_indexer_k_cache_buffers (
10821151 self .layer_idx )
1083- logits_decode = fp8_paged_mqa_logits (
1084- q_decode ,
1085- k_cache ,
1086- weights_decode ,
1087- metadata .kv_lens_cuda_runtime [
1088- num_contexts :num_contexts +
1089- num_generations ], # context_lens prepared in prepare()
1090- metadata .indexer_k_cache_block_offsets [
1091- num_contexts :num_contexts +
1092- num_generations ], # Only pass generation request block tables
1093- metadata .scheduler_metadata_buffer ,
1094- max_seq_len )
1152+ logits_decode = fp8_paged_mqa_logits (q_decode , k_cache ,
1153+ weights_decode , context_lens ,
1154+ block_table ,
1155+ scheduler_metadata_buffer ,
1156+ max_seq_len )
10951157 # padded
10961158 positions = torch .arange (
10971159 max_seq_len ,
0 commit comments