
Commit ee9c050

add mtp3 support.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent 1c6e490 commit ee9c050


2 files changed: +28 -4 lines changed


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 25 additions & 3 deletions
@@ -431,6 +431,18 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        self.kv_lens_expanded_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences * (1 + self.max_draft_tokens), ),
+            cache_name="kv_lens_expanded_cuda",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.kv_lens_expanded_host = torch.zeros_like(
+            self.kv_lens_expanded_cuda,
+            device='cpu',
+            pin_memory=True,
+        )
 
     def prepare(self):
         super().prepare()

@@ -534,6 +546,10 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
+        # Expand kv_lens_cuda for draft tokens (only generation)
+        gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+        self.kv_lens_expanded_host = torch.cat([gen_kv_lens] * (1+self.max_draft_tokens), dim=0)
+
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)

@@ -1047,9 +1063,15 @@ def sparse_attn_indexer(
         # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
         q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
                          ...]
-        q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-        batch_size = q_decode.shape[0]
-        next_n = q_decode.shape[1]
+        batch_size = num_generations
+        next_n = num_gen_tokens // num_generations
+        if next_n <= 2:
+            q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
+            context_lens = metadata.kv_lens_cuda_runtime[num_contexts:num_contexts +
+                                                         num_generations]
+        else:
+            q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+
         assert num_gen_tokens == batch_size * next_n
         weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
                                  num_gen_tokens, ...]
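
For readers following the dsa.py change: with MTP each generation request carries next_n = 1 + max_draft_tokens query tokens, and when next_n exceeds 2 the indexer reshapes q_decode to one token per row, so the per-request KV lengths are replicated once per position via the new kv_lens_expanded buffers. Below is a minimal standalone sketch (illustrative names, shapes, and values only, not the repository code) of how that expansion is sized and filled:

```python
# Sketch only: illustrates the kv_lens expansion pattern from the diff above.
import torch

max_num_sequences = 4   # assumed upper bound on in-flight sequences
max_draft_tokens = 3    # e.g. MTP3 -> three draft tokens per step
next_n = 1 + max_draft_tokens

# Persistent device-sized buffer, allocated once for the worst case
# (mirrors self.kv_lens_expanded_cuda, here on CPU for illustration).
kv_lens_expanded = torch.zeros(max_num_sequences * next_n, dtype=torch.int32)

# Per-step fill: replicate the generation-phase kv lengths once per position,
# mirroring torch.cat([gen_kv_lens] * (1 + max_draft_tokens), dim=0).
gen_kv_lens = torch.tensor([17, 42], dtype=torch.int32)  # two generation requests
expanded = torch.cat([gen_kv_lens] * next_n, dim=0)
kv_lens_expanded[:expanded.numel()] = expanded
print(expanded.tolist())  # [17, 42, 17, 42, 17, 42, 17, 42]
```

The separate pinned host tensor mirrors the device buffer so the per-step values can be staged on CPU first, following the pattern of the other capture-graph buffers allocated in __post_init__.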

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 3 additions & 1 deletion
@@ -597,6 +597,8 @@ class TrtllmAttentionMetadata(AttentionMetadata):
     is_spec_decoding_enabled: bool = False
     # use_spec_decoding determines if the attention layer should be run in spec-dec mode at the specific step / layer.
     use_spec_decoding: bool = False
+    # max number of draft tokens
+    max_draft_tokens: int = 0
 
     # if spec-dec tree is a tree or a chain (linear tree)
     is_spec_dec_tree: bool = False

@@ -1067,7 +1069,7 @@ def update_spec_dec_param(
         max_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
-
+        self.max_draft_tokens = max_draft_tokens
         if spec_decoding_tensor is not None:
            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
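
The trtllm.py side is plumbing: the attention metadata grows a max_draft_tokens field and update_spec_dec_param records the value, so later steps (such as the kv_lens expansion above) can size their buffers from it. A minimal sketch of that pattern, using a hypothetical class name rather than the actual TrtllmAttentionMetadata:

```python
# Sketch only: shows how the new field is recorded on the metadata object.
from dataclasses import dataclass

@dataclass
class SpecDecMetadataSketch:
    is_spec_decoding_enabled: bool = False
    use_spec_decoding: bool = False
    # max number of draft tokens (the field added in this commit)
    max_draft_tokens: int = 0

    def update_spec_dec_param(self, max_draft_tokens: int) -> None:
        # Record the value so buffer sizing (e.g. kv_lens_expanded) can read it.
        self.max_draft_tokens = max_draft_tokens

meta = SpecDecMetadataSketch()
meta.update_spec_dec_param(3)  # MTP3: three draft tokens per step
assert meta.max_draft_tokens == 3
```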
