Skip to content

Commit 2ba510b

Browse files
committed
fix for MTP>1.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent f201d10 commit 2ba510b

File tree

1 file changed

+23
-1
lines changed
  • tensorrt_llm/_torch/attention_backend/sparse

1 file changed

+23
-1
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,10 @@ def __post_init__(self):
432432
dtype=torch.int32,
433433
capture_graph=capture_graph,
434434
)
435-
# TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
435+
self.create_expanded_buffers(capture_graph=capture_graph)
436+
437+
# TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
438+
def create_expanded_buffers(self, capture_graph=False):
436439
self.kv_lens_expanded_cuda = self.get_empty(
437440
self.cuda_graph_buffers,
438441
(self.max_num_sequences * (1 + self.max_draft_tokens), ),
@@ -468,6 +471,25 @@ def __post_init__(self):
468471
capture_graph=capture_graph,
469472
)
470473

474+
# This function is only used to create the expanded buffers when the max_draft_tokens is changed.
475+
# TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
476+
def update_spec_dec_param(
477+
self,
478+
is_spec_decoding_enabled,
479+
is_spec_dec_tree,
480+
is_spec_dec_dynamic_tree,
481+
max_draft_tokens,
482+
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
483+
):
484+
super().update_spec_dec_param(is_spec_decoding_enabled,
485+
is_spec_dec_tree,
486+
is_spec_dec_dynamic_tree,
487+
max_draft_tokens, spec_decoding_tensor)
488+
init_shape = self.kv_lens_expanded_host.shape[0]
489+
if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
490+
capture_graph = torch.cuda.is_current_stream_capturing()
491+
self.create_expanded_buffers(capture_graph=capture_graph)
492+
471493
def prepare(self):
472494
super().prepare()
473495
if self.kv_cache_manager is not None:

0 commit comments

Comments
 (0)