
Commit 9d7e52b

add mtp3 support.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent ee9c050 commit 9d7e52b

File tree

1 file changed (+79, -17 lines)
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 79 additions & 17 deletions
@@ -443,6 +443,28 @@ def __post_init__(self):
             device='cpu',
             pin_memory=True,
         )
+        self.block_table_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            [
+                self.max_num_sequences * (1 + self.max_draft_tokens),
+                self.kv_cache_manager.max_blocks_per_seq
+            ],
+            cache_name="block_table_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.host_block_table_expanded = torch.zeros_like(
+            self.block_table_expanded,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.scheduler_metadata_buffer_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.num_sms + 1, 2),
+            cache_name="scheduler_metadata_buffer_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
 
     def prepare(self):
         super().prepare()
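
The new buffers all follow the same MTP > 1 layout: one row (or slot) per flattened query token, i.e. per (generation request, draft position) pair. A minimal sketch of the resulting shapes, using made-up placeholder sizes rather than values from this commit:

import torch

# Placeholder sizes for illustration only (not taken from the commit).
max_num_sequences = 4    # generation requests that may run in one batch
max_draft_tokens = 3     # MTP3: three draft tokens per target token
max_blocks_per_seq = 8   # paged-KV blocks a single sequence may span
num_sms = 132            # SM count of the target GPU

# One block-table row per flattened query token, so each of the
# (1 + max_draft_tokens) tokens of a request can be treated as its own
# length-1 sequence by fp8_paged_mqa_logits.
block_table_expanded = torch.zeros(
    (max_num_sequences * (1 + max_draft_tokens), max_blocks_per_seq),
    dtype=torch.int32)

scheduler_metadata_buffer_expanded = torch.zeros((num_sms + 1, 2),
                                                 dtype=torch.int32)

print(block_table_expanded.shape)                # torch.Size([16, 8])
print(scheduler_metadata_buffer_expanded.shape)  # torch.Size([133, 2])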
@@ -546,9 +568,38 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
-        # Expand kv_lens_cuda for draft tokens (only generation)
-        gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
-        self.kv_lens_expanded_host = torch.cat([gen_kv_lens] * (1+self.max_draft_tokens), dim=0)
+        # Because fp8_paged_mqa_logits only supports seq_len == 1 or 2, it cannot
+        # support MTP > 1 directly. To handle this, when MTP > 1 we flatten the q
+        # tensor and expand kv_lens and block_table so fp8_paged_mqa_logits can still be used.
+        if self.max_draft_tokens > 1:
+            # Expand kv_lens_cuda (only generation)
+            num_tokens = self.num_generations * (1 + self.max_draft_tokens)
+            gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+            gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
+                                               (1 + self.max_draft_tokens),
+                                               dim=0)
+            gen_kv_lens_expanded = gen_kv_lens_expanded.transpose(
+                0, 1).contiguous().flatten()
+            self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded)
+            self.kv_lens_expanded_cuda[:num_tokens].copy_(
+                self.kv_lens_expanded_host[:num_tokens], non_blocking=True)
+
+            # Expand indexer_k_cache_block_offsets (only generation)
+            if self.kv_cache_manager is not None:
+                block_ids = self.kv_cache_manager.get_batch_cache_indices(
+                    self.request_ids)
+                for i in range(self.num_contexts, len(block_ids)):
+                    for j in range(1 + self.max_draft_tokens):
+                        self.host_block_table_expanded[
+                            (i - self.num_contexts) *
+                            (1 + self.max_draft_tokens) +
+                            j, :len(block_ids[i])].copy_(
+                                torch.tensor(block_ids[i],
+                                             dtype=torch.int32,
+                                             device='cpu'))
+                self.block_table_expanded[:num_tokens].copy_(
+                    self.host_block_table_expanded[:num_tokens],
+                    non_blocking=True)
 
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
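
The stack/transpose/flatten sequence interleaves the expanded lengths per request, so token j of request i lands at index i * (1 + max_draft_tokens) + j, matching the row layout written into host_block_table_expanded above. A tiny worked example (the values are illustrative only):

import torch

max_draft_tokens = 3                                      # MTP3
gen_kv_lens = torch.tensor([10, 25], dtype=torch.int32)   # two generation requests

expanded = torch.stack([gen_kv_lens] * (1 + max_draft_tokens), dim=0)
expanded = expanded.transpose(0, 1).contiguous().flatten()

# Every flattened query token of a request sees the same KV length:
print(expanded)  # tensor([10, 10, 10, 10, 25, 25, 25, 25], dtype=torch.int32)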
@@ -814,6 +865,15 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
             gen_seq_lens, tokens_per_block, metadata.num_sms)
         metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
                                                  non_blocking=True)
+        if metadata.max_draft_tokens > 1:
+            # Expand scheduler metadata buffer (only generation)
+            num_tokens = metadata.num_generations * (
+                1 + metadata.max_draft_tokens)
+            kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens]
+            scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
+                kv_lens_expanded, tokens_per_block, metadata.num_sms)
+            metadata.scheduler_metadata_buffer_expanded.copy_(
+                scheduler_metadata_buffer_expanded, non_blocking=True)
 
         # Compute slot_mapping for all requests (both context and generation)
         # This maps each token to its flat cache position for vectorized KV cache updates
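
The expanded scheduler metadata is the existing get_paged_mqa_logits_metadata helper re-run over the expanded KV lengths, with the result copied into the buffer preallocated in __post_init__ so the pointer captured by a CUDA graph stays stable. A rough sketch of that pattern, with a stub standing in for the real helper:

import torch

def paged_mqa_logits_metadata_stub(kv_lens, tokens_per_block, num_sms):
    # Stand-in for get_paged_mqa_logits_metadata; only the output shape
    # (num_sms + 1, 2), matching the preallocated buffer, is assumed here.
    return torch.zeros((num_sms + 1, 2), dtype=torch.int32)

num_sms, tokens_per_block = 4, 64
kv_lens_expanded = torch.tensor([10, 10, 10, 10, 25, 25, 25, 25],
                                dtype=torch.int32)

# Preallocated once (see the __post_init__ hunk), refreshed in place each step.
scheduler_metadata_buffer_expanded = torch.empty((num_sms + 1, 2),
                                                 dtype=torch.int32)
scheduler_metadata_buffer_expanded.copy_(
    paged_mqa_logits_metadata_stub(kv_lens_expanded, tokens_per_block, num_sms))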
@@ -1065,12 +1125,21 @@ def sparse_attn_indexer(
                          ...]
         batch_size = num_generations
         next_n = num_gen_tokens // num_generations
+        # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor
+        # and expand the corresponding metadata.
         if next_n <= 2:
             q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-            context_lens = metadata.kv_lens_cuda_runtime[num_contexts:num_contexts +
-                                                         num_generations]
+            context_lens = metadata.kv_lens_cuda_runtime[
+                num_contexts:num_contexts + num_generations]
+            block_table = metadata.indexer_k_cache_block_offsets[
+                num_contexts:num_contexts + num_generations]
+            scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
         else:
             q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+            num_tokens = num_generations * (1 + metadata.max_draft_tokens)
+            context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
+            block_table = metadata.block_table_expanded[:num_tokens]
+            scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded
 
         assert num_gen_tokens == batch_size * next_n
         weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
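
The two branches differ only in how the decode queries and their metadata are laid out: for next_n <= 2 the kernel sees num_generations sequences of length next_n, while for next_n > 2 every token becomes its own length-1 sequence and the expanded metadata prepared above is used instead. A toy illustration of the reshape (head count and dimensions are placeholders):

import torch

num_generations, next_n = 2, 4      # MTP3: 1 target + 3 draft tokens
num_heads, head_dim = 8, 128        # placeholder shapes
q_decode = torch.randn(num_generations * next_n, num_heads, head_dim)

if next_n <= 2:
    # num_generations sequences, next_n tokens each
    q_decode = q_decode.view(num_generations, -1, num_heads, head_dim)
else:
    # Flatten: each token is its own length-1 "sequence", which is why
    # kv_lens, the block table and the scheduler metadata were expanded.
    q_decode = q_decode.view(-1, 1, num_heads, head_dim)

print(q_decode.shape)  # torch.Size([8, 1, 8, 128]) for next_n == 4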
@@ -1080,18 +1149,11 @@ def sparse_attn_indexer(
         # [num_blocks, tokens_per_block, 1, head_dim + scale_size]
         k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
             self.layer_idx)
-        logits_decode = fp8_paged_mqa_logits(
-            q_decode,
-            k_cache,
-            weights_decode,
-            metadata.kv_lens_cuda_runtime[
-                num_contexts:num_contexts +
-                num_generations],  # context_lens prepared in prepare()
-            metadata.indexer_k_cache_block_offsets[
-                num_contexts:num_contexts +
-                num_generations],  # Only pass generation request block tables
-            metadata.scheduler_metadata_buffer,
-            max_seq_len)
+        logits_decode = fp8_paged_mqa_logits(q_decode, k_cache,
+                                             weights_decode, context_lens,
+                                             block_table,
+                                             scheduler_metadata_buffer,
+                                             max_seq_len)
         # padded
         positions = torch.arange(
             max_seq_len,
