From c3e146996ec16e603cf514e76c56fff40604a2f3 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Sun, 9 Nov 2025 08:09:52 -0800 Subject: [PATCH 1/9] add mtp3 support. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/attention_backend/sparse/dsa.py | 28 +++++++++++++++++-- .../_torch/attention_backend/trtllm.py | 4 ++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 84a9d63b7a0..ca460bb2c57 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -432,6 +432,18 @@ def __post_init__(self): dtype=torch.int32, capture_graph=capture_graph, ) + self.kv_lens_expanded_cuda = self.get_empty( + self.cuda_graph_buffers, + (self.max_num_sequences * (1 + self.max_draft_tokens), ), + cache_name="kv_lens_expanded_cuda", + dtype=torch.int32, + capture_graph=capture_graph, + ) + self.kv_lens_expanded_host = torch.zeros_like( + self.kv_lens_expanded_cuda, + device='cpu', + pin_memory=True, + ) def prepare(self): super().prepare() @@ -535,6 +547,10 @@ def prepare(self): else: self.max_gen_seq_len = 0 + # Expand kv_lens_cuda for draft tokens (only generation) + gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs] + self.kv_lens_expanded_host = torch.cat([gen_kv_lens] * (1+self.max_draft_tokens), dim=0) + # Prepare metadata for indexer Indexer.prepare(metadata=self) @@ -1053,9 +1069,15 @@ def sparse_attn_indexer( # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...] q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens, ...] - q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:]) - batch_size = q_decode.shape[0] - next_n = q_decode.shape[1] + batch_size = num_generations + next_n = num_gen_tokens // num_generations + if next_n <= 2: + q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:]) + context_lens = metadata.kv_lens_cuda_runtime[num_contexts:num_contexts + + num_generations] + else: + q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:]) + assert num_gen_tokens == batch_size * next_n weights_decode = weights[num_ctx_tokens:num_ctx_tokens + num_gen_tokens, ...] diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index a7bc7c4490f..3cc120ed634 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -597,6 +597,8 @@ class TrtllmAttentionMetadata(AttentionMetadata): is_spec_decoding_enabled: bool = False # use_spec_decoding determines if the attention layer should be run in spec-dec mode at the specific step / layer. use_spec_decoding: bool = False + # max number of draft tokens + max_draft_tokens: int = 0 # if spec-dec tree is a tree or a chain (linear tree) is_spec_dec_tree: bool = False @@ -1067,7 +1069,7 @@ def update_spec_dec_param( max_draft_tokens, spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, ): - + self.max_draft_tokens = max_draft_tokens if spec_decoding_tensor is not None: spec_decoding_position_offsets = spec_decoding_tensor.position_offsets spec_decoding_packed_mask = spec_decoding_tensor.packed_mask From d5b7d43df850c8b683b52449a13c2e04925af742 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 10 Nov 2025 06:15:49 -0800 Subject: [PATCH 2/9] add mtp3 support. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/attention_backend/sparse/dsa.py | 96 +++++++++++++++---- 1 file changed, 79 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index ca460bb2c57..244724d9807 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -444,6 +444,28 @@ def __post_init__(self): device='cpu', pin_memory=True, ) + self.block_table_expanded = self.get_empty( + self.cuda_graph_buffers, + [ + self.max_num_sequences * (1 + self.max_draft_tokens), + self.kv_cache_manager.max_blocks_per_seq + ], + cache_name="block_table_expanded", + dtype=torch.int32, + capture_graph=capture_graph, + ) + self.host_block_table_expanded = torch.zeros_like( + self.block_table_expanded, + device='cpu', + pin_memory=True, + ) + self.scheduler_metadata_buffer_expanded = self.get_empty( + self.cuda_graph_buffers, + (self.num_sms + 1, 2), + cache_name="scheduler_metadata_buffer_expanded", + dtype=torch.int32, + capture_graph=capture_graph, + ) def prepare(self): super().prepare() @@ -547,9 +569,38 @@ def prepare(self): else: self.max_gen_seq_len = 0 - # Expand kv_lens_cuda for draft tokens (only generation) - gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs] - self.kv_lens_expanded_host = torch.cat([gen_kv_lens] * (1+self.max_draft_tokens), dim=0) + # Because the fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot support + # MTP > 1. To haddle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and + # block_table for to use the fp8_paged_mqa_logits. + if self.max_draft_tokens > 1: + # Expand kv_lens_cuda (only generation) + num_tokens = self.num_generations * (1 + self.max_draft_tokens) + gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs] + gen_kv_lens_expanded = torch.stack([gen_kv_lens] * + (1 + self.max_draft_tokens), + dim=0) + gen_kv_lens_expanded = gen_kv_lens_expanded.transpose( + 0, 1).contiguous().flatten() + self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded) + self.kv_lens_expanded_cuda[:num_tokens].copy_( + self.kv_lens_expanded_host[:num_tokens], non_blocking=True) + + # Expand indexer_k_cache_block_offsets (only generation) + if self.kv_cache_manager is not None: + block_ids = self.kv_cache_manager.get_batch_cache_indices( + self.request_ids) + for i in range(self.num_contexts, len(block_ids)): + for j in range(1 + self.max_draft_tokens): + self.host_block_table_expanded[ + (i - self.num_contexts) * + (1 + self.max_draft_tokens) + + j, :len(block_ids[i])].copy_( + torch.tensor(block_ids[i], + dtype=torch.int32, + device='cpu')) + self.block_table_expanded[:num_tokens].copy_( + self.host_block_table_expanded[:num_tokens], + non_blocking=True) # Prepare metadata for indexer Indexer.prepare(metadata=self) @@ -821,6 +872,15 @@ def prepare(metadata: DSAtrtllmAttentionMetadata): gen_seq_lens, tokens_per_block, metadata.num_sms) metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer, non_blocking=True) + if metadata.max_draft_tokens > 1: + # Expand schedule metadata buffer (only generation) + num_tokens = metadata.num_generations * ( + 1 + metadata.max_draft_tokens) + kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens] + scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata( + kv_lens_expanded, tokens_per_block, metadata.num_sms) + metadata.scheduler_metadata_buffer_expanded.copy_( + 
scheduler_metadata_buffer_expanded, non_blocking=True) # Compute slot_mapping for all requests (both context and generation) # This maps each token to its flat cache position for vectorized KV cache updates @@ -1071,12 +1131,21 @@ def sparse_attn_indexer( ...] batch_size = num_generations next_n = num_gen_tokens // num_generations + # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor + # and expand the corresponding metadata. if next_n <= 2: q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:]) - context_lens = metadata.kv_lens_cuda_runtime[num_contexts:num_contexts + - num_generations] + context_lens = metadata.kv_lens_cuda_runtime[ + num_contexts:num_contexts + num_generations] + block_table = metadata.indexer_k_cache_block_offsets[ + num_contexts:num_contexts + num_generations] + scheduler_metadata_buffer = metadata.scheduler_metadata_buffer else: q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:]) + num_tokens = num_generations * (1 + metadata.max_draft_tokens) + context_lens = metadata.kv_lens_expanded_cuda[:num_tokens] + block_table = metadata.block_table_expanded[:num_tokens] + scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded assert num_gen_tokens == batch_size * next_n weights_decode = weights[num_ctx_tokens:num_ctx_tokens + @@ -1086,18 +1155,11 @@ def sparse_attn_indexer( # [num_blocks, tokens_per_block, 1, head_dim + scale_size] k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers( self.layer_idx) - logits_decode = fp8_paged_mqa_logits( - q_decode, - k_cache, - weights_decode, - metadata.kv_lens_cuda_runtime[ - num_contexts:num_contexts + - num_generations], # context_lens prepared in prepare() - metadata.indexer_k_cache_block_offsets[ - num_contexts:num_contexts + - num_generations], # Only pass generation request block tables - metadata.scheduler_metadata_buffer, - max_seq_len) + logits_decode = fp8_paged_mqa_logits(q_decode, k_cache, + weights_decode, context_lens, + block_table, + scheduler_metadata_buffer, + max_seq_len) if use_custom_topk: # Kernel expects kv_lens (total cache length), not seq_lens (new tokens) From 56e97a1cc6900d36205f354661b1a7dfcf6d5f2d Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 10 Nov 2025 07:15:05 -0800 Subject: [PATCH 3/9] add mtp3 tests. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index fe95c4cc093..d41560e3cb1 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2380,7 +2380,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness): (8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT"), (8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT"), (8, 1, 8, 0, True, True, True, True, 24, "_DEFAULT"), - (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"), + (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"), ], ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"]) def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, @@ -2448,7 +2448,7 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, (8, 1, 8, 0, False, True, True, True, 24, "CUTLASS"), (8, 1, 8, 1, False, True, True, True, 24, "CUTLASS"), (8, 1, 8, 0, True, True, True, True, 24, "CUTLASS"), - (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"), + (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"), ], ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"]) def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, From bbc2b08872703feade5ea47bcb76c1834025c01d Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:45:19 -0800 Subject: [PATCH 4/9] address comments. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/attention_backend/sparse/dsa.py | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 244724d9807..ceaa38d5c14 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -570,8 +570,9 @@ def prepare(self): self.max_gen_seq_len = 0 # Because the fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot support - # MTP > 1. To haddle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and + # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and # block_table for to use the fp8_paged_mqa_logits. + # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1. 
if self.max_draft_tokens > 1: # Expand kv_lens_cuda (only generation) num_tokens = self.num_generations * (1 + self.max_draft_tokens) @@ -589,18 +590,19 @@ def prepare(self): if self.kv_cache_manager is not None: block_ids = self.kv_cache_manager.get_batch_cache_indices( self.request_ids) - for i in range(self.num_contexts, len(block_ids)): - for j in range(1 + self.max_draft_tokens): - self.host_block_table_expanded[ - (i - self.num_contexts) * - (1 + self.max_draft_tokens) + - j, :len(block_ids[i])].copy_( - torch.tensor(block_ids[i], - dtype=torch.int32, - device='cpu')) - self.block_table_expanded[:num_tokens].copy_( - self.host_block_table_expanded[:num_tokens], - non_blocking=True) + gen_block_ids = block_ids[self.num_contexts:] + if len(gen_block_ids) > 0: + # Find max length and create padded tensor + max_len = max(len(bid) for bid in gen_block_ids) + gen_block_tensor = self.host_indexer_k_cache_block_offsets[ + self.num_contexts:self.num_seqs, :max_len] + expanded_blocks = gen_block_tensor.repeat_interleave( + 1 + self.max_draft_tokens, dim=0) + self.host_block_table_expanded[:num_tokens, :max_len].copy_( + expanded_blocks, non_blocking=True) + self.block_table_expanded[:num_tokens].copy_( + self.host_block_table_expanded[:num_tokens], + non_blocking=True) # Prepare metadata for indexer Indexer.prepare(metadata=self) @@ -866,13 +868,14 @@ def prepare(metadata: DSAtrtllmAttentionMetadata): if num_generations > 0: # Prepare schedule metadata for fp8_paged_mqa_logits # This is a preprocessing step that computes scheduling information for the kernel - gen_seq_lens = metadata.kv_lens_cuda_runtime[ - num_contexts:num_contexts + num_generations] - scheduler_metadata_buffer = get_paged_mqa_logits_metadata( - gen_seq_lens, tokens_per_block, metadata.num_sms) - metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer, - non_blocking=True) - if metadata.max_draft_tokens > 1: + if metadata.max_draft_tokens <= 1: + gen_seq_lens = metadata.kv_lens_cuda_runtime[ + num_contexts:num_contexts + num_generations] + scheduler_metadata_buffer = get_paged_mqa_logits_metadata( + gen_seq_lens, tokens_per_block, metadata.num_sms) + metadata.scheduler_metadata_buffer.copy_( + scheduler_metadata_buffer, non_blocking=True) + else: # Expand schedule metadata buffer (only generation) num_tokens = metadata.num_generations * ( 1 + metadata.max_draft_tokens) From 3c4e6a497089cad183c30234cfa2e3afa850ba3b Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:47:38 -0800 Subject: [PATCH 5/9] add TODO. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/sparse/dsa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index ceaa38d5c14..0a93e687577 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -432,6 +432,7 @@ def __post_init__(self): dtype=torch.int32, capture_graph=capture_graph, ) + # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1. self.kv_lens_expanded_cuda = self.get_empty( self.cuda_graph_buffers, (self.max_num_sequences * (1 + self.max_draft_tokens), ), From 3c6348415be6ef0387091f6b5dd92bd393a1d604 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 11 Nov 2025 00:57:41 -0800 Subject: [PATCH 6/9] fix for MTP>1. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/attention_backend/sparse/dsa.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 0a93e687577..6aa9f5c3849 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -432,7 +432,10 @@ def __post_init__(self): dtype=torch.int32, capture_graph=capture_graph, ) - # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1. + self.create_expanded_buffers(capture_graph=capture_graph) + + # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1. + def create_expanded_buffers(self, capture_graph=False): self.kv_lens_expanded_cuda = self.get_empty( self.cuda_graph_buffers, (self.max_num_sequences * (1 + self.max_draft_tokens), ), @@ -468,6 +471,25 @@ def __post_init__(self): capture_graph=capture_graph, ) + # This function is only used to create the expanded buffers when the max_draft_tokens is changed. + # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1. + def update_spec_dec_param( + self, + is_spec_decoding_enabled, + is_spec_dec_tree, + is_spec_dec_dynamic_tree, + max_draft_tokens, + spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, + ): + super().update_spec_dec_param(is_spec_decoding_enabled, + is_spec_dec_tree, + is_spec_dec_dynamic_tree, + max_draft_tokens, spec_decoding_tensor) + init_shape = self.kv_lens_expanded_host.shape[0] + if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape: + capture_graph = torch.cuda.is_current_stream_capturing() + self.create_expanded_buffers(capture_graph=capture_graph) + def prepare(self): super().prepare() if self.kv_cache_manager is not None: From cfe1c4cc1fbb6de7465b03172a18bda72b05b0e8 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 11 Nov 2025 04:20:39 -0800 Subject: [PATCH 7/9] fix test_dsa_indexer. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../attention/sparse/test_dsa_indexer.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py index 6cfe276a105..75da9cebba2 100644 --- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py +++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py @@ -381,7 +381,8 @@ def _create_mock_metadata(request_ids, cache_manager, num_ctx_tokens, num_tokens, - indexer_max_chunk_size=8194): + indexer_max_chunk_size=8194, + max_draft_tokens=0): """Helper to create mock metadata for testing.""" class MockKVCacheParams: @@ -396,6 +397,7 @@ def __init__(self): self.request_ids = request_ids self.num_contexts = num_contexts self.num_generations = num_generations + self.max_draft_tokens = max_draft_tokens # Keep seq_lens on CPU for split_prefill_chunks and other CPU operations # CUDA kernels will convert to CUDA as needed self.seq_lens = seq_lens.cpu() if seq_lens.is_cuda else seq_lens @@ -826,6 +828,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n): cache_manager=cache_manager, num_ctx_tokens=total_context_tokens, num_tokens=total_context_tokens, + max_draft_tokens=next_n - 1, ) Indexer.prepare(metadata_context) @@ -851,6 +854,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n): cache_manager=cache_manager, num_ctx_tokens=0, num_tokens=batch_size * num_gen_tokens, + max_draft_tokens=next_n - 1, ) Indexer.prepare(metadata_gen) @@ -1418,6 +1422,7 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk, cache_manager=cache_manager, num_ctx_tokens=total_context_tokens, num_tokens=total_context_tokens, + max_draft_tokens=next_n - 1, ) Indexer.prepare(metadata_context) indexer._update_k_cache(k_context_fp8, k_context_scale, metadata_context) @@ -1450,16 +1455,24 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk, cache_manager=cache_manager, num_ctx_tokens=0, num_tokens=num_gen_tokens, + max_draft_tokens=next_n - 1, ) Indexer.prepare(metadata_gen_write) indexer._update_k_cache(k_fp8, k_scale, metadata_gen_write) # Test with custom CUDA kernel - metadata_custom = _create_mock_metadata(request_ids, batch_size, 0, - batch_size, seq_lens.clone(), + metadata_custom = _create_mock_metadata(request_ids, + batch_size, + 0, + batch_size, + seq_lens.clone(), final_lens.clone(), - num_cached_tokens, cache_manager, 0, - num_gen_tokens, max_model_len) + num_cached_tokens, + cache_manager, + 0, + num_gen_tokens, + max_model_len, + max_draft_tokens=next_n - 1) Indexer.prepare(metadata_custom) indexer._update_k_cache(k_fp8, k_scale, metadata_custom) @@ -1476,11 +1489,18 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk, pytest.skip(f"Custom topk not available: {e}") # Test with PyTorch fallback - metadata_fallback = _create_mock_metadata(request_ids, batch_size, 0, - batch_size, seq_lens.clone(), + metadata_fallback = _create_mock_metadata(request_ids, + batch_size, + 0, + batch_size, + seq_lens.clone(), final_lens.clone(), - num_cached_tokens, cache_manager, - 0, num_gen_tokens, max_model_len) + num_cached_tokens, + cache_manager, + 0, + num_gen_tokens, + max_model_len, + max_draft_tokens=next_n - 1) Indexer.prepare(metadata_fallback) indexer._update_k_cache(k_fp8, k_scale, metadata_fallback) From 311e743a22c8b53e6e211ecfce2e7ed53a4689da Mon Sep 17 00:00:00 2001 From: Fanrong Li 
<23290157+lfr-0531@users.noreply.github.com> Date: Tue, 11 Nov 2025 21:44:14 -0800 Subject: [PATCH 8/9] address comments. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/sparse/dsa.py | 3 +++ tensorrt_llm/_torch/attention_backend/trtllm.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 6aa9f5c3849..4db8930079e 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata): indexer_max_chunk_size: int # Topk for sparse MLA sparse_mla_topk: int + # max number of draft tokens + max_draft_tokens: int = 0 def __init__(self, *args, **kwargs): self.num_sms = tensorrt_llm.deep_gemm.get_num_sms() @@ -485,6 +487,7 @@ def update_spec_dec_param( is_spec_dec_tree, is_spec_dec_dynamic_tree, max_draft_tokens, spec_decoding_tensor) + self.max_draft_tokens = max_draft_tokens init_shape = self.kv_lens_expanded_host.shape[0] if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape: capture_graph = torch.cuda.is_current_stream_capturing() diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 3cc120ed634..cb2d16918fe 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -597,8 +597,6 @@ class TrtllmAttentionMetadata(AttentionMetadata): is_spec_decoding_enabled: bool = False # use_spec_decoding determines if the attention layer should be run in spec-dec mode at the specific step / layer. use_spec_decoding: bool = False - # max number of draft tokens - max_draft_tokens: int = 0 # if spec-dec tree is a tree or a chain (linear tree) is_spec_dec_tree: bool = False @@ -1069,7 +1067,6 @@ def update_spec_dec_param( max_draft_tokens, spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, ): - self.max_draft_tokens = max_draft_tokens if spec_decoding_tensor is not None: spec_decoding_position_offsets = spec_decoding_tensor.position_offsets spec_decoding_packed_mask = spec_decoding_tensor.packed_mask From 9b56f8ac6daf68193cd52c2f9cda318bce7b11ed Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 11 Nov 2025 21:46:17 -0800 Subject: [PATCH 9/9] address comments. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/trtllm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index cb2d16918fe..a7bc7c4490f 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1067,6 +1067,7 @@ def update_spec_dec_param( max_draft_tokens, spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, ): + if spec_decoding_tensor is not None: spec_decoding_position_offsets = spec_decoding_tensor.position_offsets spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
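
The series revolves around one workaround, spelled out in the comment added in PATCH 2 and reworded in PATCH 4: fp8_paged_mqa_logits only supports next_n == 1 or 2, so when max_draft_tokens > 1 the generation-phase q tensor is flattened to one row per token and the per-request kv_lens and indexer block table are replicated (1 + max_draft_tokens) times, giving each flattened token its own "request". The sketch below reproduces just that expansion step on plain host tensors; the function name, toy shapes, and values are illustrative stand-ins, not part of the patch or of the TensorRT-LLM API.

    # Standalone sketch (assumed helper, not from the patch): expand generation-phase
    # kv_lens and block-table rows so a kernel limited to next_n == 1 can consume a
    # flattened q tensor when MTP (max_draft_tokens) > 1.
    import torch

    def expand_gen_metadata(gen_kv_lens: torch.Tensor,
                            gen_block_table: torch.Tensor,
                            max_draft_tokens: int):
        """Repeat each generation request's kv_len and block-table row
        (1 + max_draft_tokens) times, one copy per flattened query token,
        in request-major order (matches the patch's stack/transpose/flatten
        and repeat_interleave usage)."""
        tokens_per_request = 1 + max_draft_tokens
        # [num_gen] -> [num_gen * tokens_per_request]
        kv_lens_expanded = gen_kv_lens.repeat_interleave(tokens_per_request, dim=0)
        # [num_gen, max_blocks_per_seq] -> [num_gen * tokens_per_request, max_blocks_per_seq]
        block_table_expanded = gen_block_table.repeat_interleave(tokens_per_request, dim=0)
        return kv_lens_expanded, block_table_expanded

    if __name__ == "__main__":
        # Two generation requests, MTP = 3 (so 4 query tokens per request).
        gen_kv_lens = torch.tensor([37, 41], dtype=torch.int32)
        gen_block_table = torch.tensor([[5, 6, 0], [7, 8, 9]], dtype=torch.int32)
        kv_lens, blocks = expand_gen_metadata(gen_kv_lens, gen_block_table,
                                              max_draft_tokens=3)
        print(kv_lens)       # tensor([37, 37, 37, 37, 41, 41, 41, 41], dtype=torch.int32)
        print(blocks.shape)  # torch.Size([8, 3])

With this expansion, q_decode can be viewed as (num_gen_tokens, 1, ...) and passed to fp8_paged_mqa_logits with next_n effectively equal to 1, which is what the q_decode.view(-1, 1, *q_fp8.shape[1:]) branch in sparse_attn_indexer relies on.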