Skip to content

Commit 311e743

Browse files
committed
address comments.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent cfe1c4c commit 311e743

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
290290
indexer_max_chunk_size: int
291291
# Topk for sparse MLA
292292
sparse_mla_topk: int
293+
# max number of draft tokens
294+
max_draft_tokens: int = 0
293295

294296
def __init__(self, *args, **kwargs):
295297
self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -485,6 +487,7 @@ def update_spec_dec_param(
485487
is_spec_dec_tree,
486488
is_spec_dec_dynamic_tree,
487489
max_draft_tokens, spec_decoding_tensor)
490+
self.max_draft_tokens = max_draft_tokens
488491
init_shape = self.kv_lens_expanded_host.shape[0]
489492
if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
490493
capture_graph = torch.cuda.is_current_stream_capturing()

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,6 @@ class TrtllmAttentionMetadata(AttentionMetadata):
597597
is_spec_decoding_enabled: bool = False
598598
# use_spec_decoding determines if the attention layer should be run in spec-dec mode at the specific step / layer.
599599
use_spec_decoding: bool = False
600-
# max number of draft tokens
601-
max_draft_tokens: int = 0
602600

603601
# if spec-dec tree is a tree or a chain (linear tree)
604602
is_spec_dec_tree: bool = False
@@ -1069,7 +1067,6 @@ def update_spec_dec_param(
10691067
max_draft_tokens,
10701068
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
10711069
):
1072-
self.max_draft_tokens = max_draft_tokens
10731070
if spec_decoding_tensor is not None:
10741071
spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
10751072
spec_decoding_packed_mask = spec_decoding_tensor.packed_mask

0 commit comments

Comments (0)