
Commit b7551b0

add mtp3 support.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent d8df21a commit b7551b0

File tree

  • tensorrt_llm/_torch/attention_backend/sparse

1 file changed (+79 -17 lines)


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 79 additions & 17 deletions
@@ -443,6 +443,28 @@ def __post_init__(self):
             device='cpu',
             pin_memory=True,
         )
+        self.block_table_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            [
+                self.max_num_sequences * (1 + self.max_draft_tokens),
+                self.kv_cache_manager.max_blocks_per_seq
+            ],
+            cache_name="block_table_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.host_block_table_expanded = torch.zeros_like(
+            self.block_table_expanded,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.scheduler_metadata_buffer_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.num_sms + 1, 2),
+            cache_name="scheduler_metadata_buffer_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
 
     def prepare(self):
         super().prepare()
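For reference, the expanded buffers above are sized per flattened generation token rather than per request. A standalone sketch of the resulting shapes for MTP3, using toy values and plain torch.zeros in place of the metadata class's get_empty helper (illustrative only, not the module's actual allocation path):

import torch

# Toy sizes, for illustration only; real values come from the metadata / kv-cache manager.
max_num_sequences = 4
max_draft_tokens = 3            # MTP3: three draft tokens per generation request
max_blocks_per_seq = 8
num_sms = 132                   # illustrative SM count

# One block-table row per flattened token: each generation request occupies
# (1 + max_draft_tokens) consecutive rows.
block_table_expanded = torch.zeros(
    max_num_sequences * (1 + max_draft_tokens), max_blocks_per_seq,
    dtype=torch.int32)                                     # shape (16, 8)

# Host mirror that prepare() fills and copies from asynchronously
# (pinned only when CUDA is present, so the sketch also runs on CPU).
host_block_table_expanded = torch.zeros(
    block_table_expanded.shape, dtype=torch.int32,
    pin_memory=torch.cuda.is_available())

# Per-SM scheduling metadata for the expanded decode launch.
scheduler_metadata_buffer_expanded = torch.zeros(num_sms + 1, 2, dtype=torch.int32)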
@@ -546,9 +568,38 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
-        # Expand kv_lens_cuda for draft tokens (only generation)
-        gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
-        self.kv_lens_expanded_host = torch.cat([gen_kv_lens] * (1+self.max_draft_tokens), dim=0)
+        # fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot handle MTP > 1
+        # directly. When MTP > 1, we flatten the q tensor and expand kv_lens and block_table
+        # so that fp8_paged_mqa_logits can still be used.
+        if self.max_draft_tokens > 1:
+            # Expand kv_lens_cuda (only generation)
+            num_tokens = self.num_generations * (1 + self.max_draft_tokens)
+            gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+            gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
+                                               (1 + self.max_draft_tokens),
+                                               dim=0)
+            gen_kv_lens_expanded = gen_kv_lens_expanded.transpose(
+                0, 1).contiguous().flatten()
+            self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded)
+            self.kv_lens_expanded_cuda[:num_tokens].copy_(
+                self.kv_lens_expanded_host[:num_tokens], non_blocking=True)
+
+            # Expand indexer_k_cache_block_offsets (only generation)
+            if self.kv_cache_manager is not None:
+                block_ids = self.kv_cache_manager.get_batch_cache_indices(
+                    self.request_ids)
+                for i in range(self.num_contexts, len(block_ids)):
+                    for j in range(1 + self.max_draft_tokens):
+                        self.host_block_table_expanded[
+                            (i - self.num_contexts) *
+                            (1 + self.max_draft_tokens) +
+                            j, :len(block_ids[i])].copy_(
+                                torch.tensor(block_ids[i],
+                                             dtype=torch.int32,
+                                             device='cpu'))
+                self.block_table_expanded[:num_tokens].copy_(
+                    self.host_block_table_expanded[:num_tokens],
+                    non_blocking=True)
 
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
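The expansion above can be reproduced in isolation. The sketch below uses toy values (two generation requests, MTP3) and plain tensors instead of the real metadata object; it shows that each request's kv length and block-table row are simply repeated (1 + max_draft_tokens) times, once per flattened query token:

import torch

max_draft_tokens = 3                                       # MTP3
gen_kv_lens = torch.tensor([10, 25], dtype=torch.int32)    # two generation requests
block_ids = [[0, 1], [2, 3, 4]]                            # toy cache block ids per request
max_blocks_per_seq = 4                                     # toy value

# kv_lens: repeat each request's length once per flattened token, keeping the
# tokens of a request contiguous -> [10, 10, 10, 10, 25, 25, 25, 25]
kv_lens_expanded = torch.stack([gen_kv_lens] * (1 + max_draft_tokens),
                               dim=0).transpose(0, 1).contiguous().flatten()

# block table: all flattened tokens of a request share the same block row.
block_table_expanded = torch.zeros(len(block_ids) * (1 + max_draft_tokens),
                                   max_blocks_per_seq, dtype=torch.int32)
for i, ids in enumerate(block_ids):
    for j in range(1 + max_draft_tokens):
        row = i * (1 + max_draft_tokens) + j
        block_table_expanded[row, :len(ids)] = torch.tensor(ids, dtype=torch.int32)

print(kv_lens_expanded.tolist())         # [10, 10, 10, 10, 25, 25, 25, 25]
print(block_table_expanded[0].tolist())  # [0, 1, 0, 0] -- first flattened token of request 0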
@@ -814,6 +865,15 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
             gen_seq_lens, tokens_per_block, metadata.num_sms)
         metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
                                                  non_blocking=True)
+        if metadata.max_draft_tokens > 1:
+            # Expand scheduler metadata buffer (only generation)
+            num_tokens = metadata.num_generations * (
+                1 + metadata.max_draft_tokens)
+            kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens]
+            scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
+                kv_lens_expanded, tokens_per_block, metadata.num_sms)
+            metadata.scheduler_metadata_buffer_expanded.copy_(
+                scheduler_metadata_buffer_expanded, non_blocking=True)
 
         # Compute slot_mapping for all requests (both context and generation)
         # This maps each token to its flat cache position for vectorized KV cache updates
@@ -1067,12 +1127,21 @@ def sparse_attn_indexer(
                             ...]
             batch_size = num_generations
             next_n = num_gen_tokens // num_generations
+            # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor
+            # and expand the corresponding metadata.
             if next_n <= 2:
                 q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-                context_lens = metadata.kv_lens_cuda_runtime[num_contexts:num_contexts +
-                                                             num_generations]
+                context_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                block_table = metadata.indexer_k_cache_block_offsets[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
             else:
                 q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+                num_tokens = num_generations * (1 + metadata.max_draft_tokens)
+                context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
+                block_table = metadata.block_table_expanded[:num_tokens]
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded
 
             assert num_gen_tokens == batch_size * next_n
             weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
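The branch above only changes tensor shapes: with next_n <= 2 the kernel sees one batch entry per generation request, while with MTP > 1 every flattened token becomes its own length-1 "sequence" whose metadata comes from the expanded buffers. A shape-only sketch with toy head dimensions (the actual fp8_paged_mqa_logits call appears in the next hunk):

import torch

num_generations, max_draft_tokens = 2, 3
next_n = 1 + max_draft_tokens            # 4 tokens per generation request
num_heads, head_dim = 64, 128            # toy head geometry
q_fp8 = torch.empty(num_generations * next_n, num_heads, head_dim)
q_decode = q_fp8                         # stand-in for the generation slice of q

if next_n <= 2:
    # one batch entry per request, next_n query positions each
    q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
else:
    # MTP > 1: one batch entry per flattened token, a single query position each;
    # rows of kv_lens_expanded_cuda / block_table_expanded line up with this view
    q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])

print(q_decode.shape)                    # torch.Size([8, 1, 64, 128]) for next_n == 4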
@@ -1082,18 +1151,11 @@ def sparse_attn_indexer(
             # [num_blocks, tokens_per_block, 1, head_dim + scale_size]
             k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
                 self.layer_idx)
-            logits_decode = fp8_paged_mqa_logits(
-                q_decode,
-                k_cache,
-                weights_decode,
-                metadata.kv_lens_cuda_runtime[
-                    num_contexts:num_contexts +
-                    num_generations],  # context_lens prepared in prepare()
-                metadata.indexer_k_cache_block_offsets[
-                    num_contexts:num_contexts +
-                    num_generations],  # Only pass generation request block tables
-                metadata.scheduler_metadata_buffer,
-                max_seq_len)
+            logits_decode = fp8_paged_mqa_logits(q_decode, k_cache,
+                                                 weights_decode, context_lens,
+                                                 block_table,
+                                                 scheduler_metadata_buffer,
+                                                 max_seq_len)
 
             if use_custom_topk:
                 # Kernel expects kv_lens (total cache length), not seq_lens (new tokens)
