2 files changed (+3, -3 lines) under tensorrt_llm/_torch/attention_backend.

In the first file, the DSA (sparse MLA) metadata subclass gains the field and starts recording it in update_spec_dec_param:

```diff
@@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
     indexer_max_chunk_size: int
     # Topk for sparse MLA
     sparse_mla_topk: int
+    # max number of draft tokens
+    max_draft_tokens: int = 0
 
     def __init__(self, *args, **kwargs):
         self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -485,6 +487,7 @@ def update_spec_dec_param(
             is_spec_dec_tree,
             is_spec_dec_dynamic_tree,
             max_draft_tokens, spec_decoding_tensor)
+        self.max_draft_tokens = max_draft_tokens
         init_shape = self.kv_lens_expanded_host.shape[0]
         if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
             capture_graph = torch.cuda.is_current_stream_capturing()
```
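The added assignment matters because the very next check sizes kv_lens_expanded_host as max_num_sequences * (1 + max_draft_tokens): one slot per sequence for its target token plus one per speculative draft token, so the value must be recorded before the comparison. Below is a minimal, self-contained sketch of that resize logic; the _SpecDecMetadataSketch class and its plain CPU tensor are illustrative assumptions, not the actual TensorRT-LLM implementation.

```python
import torch


class _SpecDecMetadataSketch:
    """Illustrative only: shows why kv_lens_expanded_host is sized
    max_num_sequences * (1 + max_draft_tokens), i.e. one slot per sequence
    for the target token plus one slot per speculative draft token."""

    def __init__(self, max_num_sequences: int):
        self.max_num_sequences = max_num_sequences
        self.max_draft_tokens = 0
        # Start sized for plain decoding (no draft tokens yet).
        self.kv_lens_expanded_host = torch.zeros(max_num_sequences,
                                                 dtype=torch.int32)

    def update_spec_dec_param(self, max_draft_tokens: int) -> None:
        # Mirrors the added line in the hunk: record the new budget first,
        # so the size check below compares against the current value.
        self.max_draft_tokens = max_draft_tokens
        init_shape = self.kv_lens_expanded_host.shape[0]
        if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
            # Reallocate the expanded buffer for the new token budget.
            self.kv_lens_expanded_host = torch.zeros(
                self.max_num_sequences * (1 + self.max_draft_tokens),
                dtype=torch.int32)


md = _SpecDecMetadataSketch(max_num_sequences=4)
md.update_spec_dec_param(max_draft_tokens=3)
assert md.kv_lens_expanded_host.shape[0] == 4 * (1 + 3)
```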
In the second file, the base class TrtllmAttentionMetadata gives up both the field and the assignment:

```diff
@@ -597,8 +597,6 @@ class TrtllmAttentionMetadata(AttentionMetadata):
     is_spec_decoding_enabled: bool = False
     # use_spec_decoding determines if the attention layer should be run in spec-dec mode at the specific step / layer.
     use_spec_decoding: bool = False
-    # max number of draft tokens
-    max_draft_tokens: int = 0
 
     # if spec-dec tree is a tree or a chain (linear tree)
    is_spec_dec_tree: bool = False
@@ -1069,7 +1067,6 @@ def update_spec_dec_param(
         max_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
-        self.max_draft_tokens = max_draft_tokens
         if spec_decoding_tensor is not None:
             spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
             spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
```