@@ -570,8 +570,9 @@ def prepare(self):
         self.max_gen_seq_len = 0
 
         # Because the fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot support
-        # MTP > 1. To haddle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and
+        # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and
         # block_table for to use the fp8_paged_mqa_logits.
+        # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
         if self.max_draft_tokens > 1:
             # Expand kv_lens_cuda (only generation)
             num_tokens = self.num_generations * (1 + self.max_draft_tokens)
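The workaround flattens each generation request into (1 + max_draft_tokens) single-token rows, so every per-request field has to be repeated once per flattened query token. A minimal standalone sketch of that expansion, assuming the hypothetical helper name expand_gen_metadata and that the kv_lens expansion (elided between these hunks) is a plain repeat_interleave:

import torch

def expand_gen_metadata(kv_lens: torch.Tensor, max_draft_tokens: int) -> torch.Tensor:
    # [num_generations] -> [num_generations * (1 + max_draft_tokens)]:
    # each generation request's kv length is repeated once per flattened query token.
    return kv_lens.repeat_interleave(1 + max_draft_tokens, dim=0)

kv_lens = torch.tensor([17, 33], dtype=torch.int32)       # two generation requests
print(expand_gen_metadata(kv_lens, max_draft_tokens=3))   # -> [17, 17, 17, 17, 33, 33, 33, 33]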
@@ -589,18 +590,19 @@ def prepare(self):
             if self.kv_cache_manager is not None:
                 block_ids = self.kv_cache_manager.get_batch_cache_indices(
                     self.request_ids)
-                for i in range(self.num_contexts, len(block_ids)):
-                    for j in range(1 + self.max_draft_tokens):
-                        self.host_block_table_expanded[
-                            (i - self.num_contexts) *
-                            (1 + self.max_draft_tokens) +
-                            j, :len(block_ids[i])].copy_(
-                                torch.tensor(block_ids[i],
-                                             dtype=torch.int32,
-                                             device='cpu'))
-                self.block_table_expanded[:num_tokens].copy_(
-                    self.host_block_table_expanded[:num_tokens],
-                    non_blocking=True)
+                gen_block_ids = block_ids[self.num_contexts:]
+                if len(gen_block_ids) > 0:
+                    # Find max length and create padded tensor
+                    max_len = max(len(bid) for bid in gen_block_ids)
+                    gen_block_tensor = self.host_indexer_k_cache_block_offsets[
+                        self.num_contexts:self.num_seqs, :max_len]
+                    expanded_blocks = gen_block_tensor.repeat_interleave(
+                        1 + self.max_draft_tokens, dim=0)
+                    self.host_block_table_expanded[:num_tokens, :max_len].copy_(
+                        expanded_blocks, non_blocking=True)
+                    self.block_table_expanded[:num_tokens].copy_(
+                        self.host_block_table_expanded[:num_tokens],
+                        non_blocking=True)
 
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
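The new path above replaces the removed per-request, per-draft-token loop (one torch.tensor allocation and copy_ per row) with a single slice of the host block-offset table and one repeat_interleave. A small self-contained check of that equivalence, with made-up shapes and values:

import torch

max_draft_tokens = 2
# Hypothetical padded block table for two generation requests, [num_gens, max_len].
gen_block_tensor = torch.tensor([[3, 4, 0], [7, 8, 9]], dtype=torch.int32)

# Vectorized expansion: each request's row is repeated (1 + max_draft_tokens) times.
vectorized = gen_block_tensor.repeat_interleave(1 + max_draft_tokens, dim=0)

# Equivalent of the removed nested loop.
looped = torch.empty_like(vectorized)
for i in range(gen_block_tensor.shape[0]):
    for j in range(1 + max_draft_tokens):
        looped[i * (1 + max_draft_tokens) + j].copy_(gen_block_tensor[i])

assert torch.equal(vectorized, looped)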
@@ -866,13 +868,14 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         if num_generations > 0:
             # Prepare schedule metadata for fp8_paged_mqa_logits
             # This is a preprocessing step that computes scheduling information for the kernel
-            gen_seq_lens = metadata.kv_lens_cuda_runtime[
-                num_contexts:num_contexts + num_generations]
-            scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
-                gen_seq_lens, tokens_per_block, metadata.num_sms)
-            metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
-                                                     non_blocking=True)
-            if metadata.max_draft_tokens > 1:
+            if metadata.max_draft_tokens <= 1:
+                gen_seq_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
+                    gen_seq_lens, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer.copy_(
+                    scheduler_metadata_buffer, non_blocking=True)
+            else:
                 # Expand schedule metadata buffer (only generation)
                 num_tokens = metadata.num_generations * (
                     1 + metadata.max_draft_tokens)
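Past the hunk boundary the MTP > 1 branch presumably recomputes the schedule over the expanded per-token lengths rather than the per-request gen_seq_lens; that continuation is not shown in this diff, so the following is only a hedged sketch, with kv_lens_expanded_cuda and scheduler_metadata_buffer_expanded standing in for whatever expanded buffers the earlier hunk fills:

                # Hypothetical continuation of the else: branch (not part of the diff).
                expanded_seq_lens = metadata.kv_lens_expanded_cuda[:num_tokens]  # assumed buffer name
                scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
                    expanded_seq_lens, tokens_per_block, metadata.num_sms)
                metadata.scheduler_metadata_buffer_expanded.copy_(               # assumed buffer name
                    scheduler_metadata_buffer, non_blocking=True)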