
Commit d87042d

address comments.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent: fb7578a

File tree
  • tensorrt_llm/_torch/attention_backend/sparse

1 file changed: +23, -20 lines


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 23 additions & 20 deletions
@@ -570,8 +570,9 @@ def prepare(self):
         self.max_gen_seq_len = 0
 
         # Because the fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot support
-        # MTP > 1. To haddle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and
+        # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand the kv_lens and
         # block_table for to use the fp8_paged_mqa_logits.
+        # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
         if self.max_draft_tokens > 1:
             # Expand kv_lens_cuda (only generation)
             num_tokens = self.num_generations * (1 + self.max_draft_tokens)
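
Editor's aside, as a hedged illustration of the "flatten the q tensor" half of the workaround described in the comment above: the minimal sketch below reshapes a multi-draft query tensor so that every draft position becomes an independent length-1 query, a shape the seq_len == 1 path of fp8_paged_mqa_logits can consume. The tensor names, shapes, and values are assumptions for illustration, not code from dsa.py.

import torch

num_generations, max_draft_tokens = 2, 3   # MTP > 1
num_heads, head_dim = 16, 64               # illustrative sizes

# Hypothetical packed generation queries: one row per request,
# (1 + max_draft_tokens) query positions per row.
q = torch.randn(num_generations, 1 + max_draft_tokens, num_heads, head_dim)

# Flatten so every draft position is treated as an independent query of
# sequence length 1, matching the kernel's seq_len == 1 restriction.
q_flat = q.reshape(num_generations * (1 + max_draft_tokens), 1, num_heads, head_dim)

assert q_flat.shape[:2] == (num_generations * (1 + max_draft_tokens), 1)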
@@ -589,18 +590,19 @@ def prepare(self):
             if self.kv_cache_manager is not None:
                 block_ids = self.kv_cache_manager.get_batch_cache_indices(
                     self.request_ids)
-                for i in range(self.num_contexts, len(block_ids)):
-                    for j in range(1 + self.max_draft_tokens):
-                        self.host_block_table_expanded[
-                            (i - self.num_contexts) *
-                            (1 + self.max_draft_tokens) +
-                            j, :len(block_ids[i])].copy_(
-                                torch.tensor(block_ids[i],
-                                             dtype=torch.int32,
-                                             device='cpu'))
-                self.block_table_expanded[:num_tokens].copy_(
-                    self.host_block_table_expanded[:num_tokens],
-                    non_blocking=True)
+                gen_block_ids = block_ids[self.num_contexts:]
+                if len(gen_block_ids) > 0:
+                    # Find max length and create padded tensor
+                    max_len = max(len(bid) for bid in gen_block_ids)
+                    gen_block_tensor = self.host_indexer_k_cache_block_offsets[
+                        self.num_contexts:self.num_seqs, :max_len]
+                    expanded_blocks = gen_block_tensor.repeat_interleave(
+                        1 + self.max_draft_tokens, dim=0)
+                    self.host_block_table_expanded[:num_tokens, :max_len].copy_(
+                        expanded_blocks, non_blocking=True)
+                    self.block_table_expanded[:num_tokens].copy_(
+                        self.host_block_table_expanded[:num_tokens],
+                        non_blocking=True)
 
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
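
The rewrite above replaces a nested Python loop with a single padded-tensor expansion. As a hedged, standalone sketch of just that step, the snippet below repeats each generation request's row of block offsets (1 + max_draft_tokens) times with repeat_interleave, so row k of the expanded table serves query k of the flattened q tensor; the buffer names and values are made up for illustration and are not the exact buffers used in dsa.py.

import torch

num_generations = 2          # generation requests in the batch
max_draft_tokens = 3         # MTP > 1: each request contributes 1 + 3 queries
max_blocks_per_seq = 4       # padded block-table width (illustrative)

# Per-request block offsets, one row per generation request
# (placeholder values, -1 marking padding).
gen_block_table = torch.tensor([[10, 11, 12, -1],
                                [20, 21, -1, -1]], dtype=torch.int32)

# Repeat each row once per flattened query.
expanded = gen_block_table.repeat_interleave(1 + max_draft_tokens, dim=0)

assert expanded.shape == (num_generations * (1 + max_draft_tokens), max_blocks_per_seq)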
@@ -866,13 +868,14 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         if num_generations > 0:
             # Prepare schedule metadata for fp8_paged_mqa_logits
             # This is a preprocessing step that computes scheduling information for the kernel
-            gen_seq_lens = metadata.kv_lens_cuda_runtime[
-                num_contexts:num_contexts + num_generations]
-            scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
-                gen_seq_lens, tokens_per_block, metadata.num_sms)
-            metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
-                                                     non_blocking=True)
-            if metadata.max_draft_tokens > 1:
+            if metadata.max_draft_tokens <= 1:
+                gen_seq_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
+                    gen_seq_lens, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer.copy_(
+                    scheduler_metadata_buffer, non_blocking=True)
+            else:
                 # Expand schedule metadata buffer (only generation)
                 num_tokens = metadata.num_generations * (
                     1 + metadata.max_draft_tokens)
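
The hunk above is cut off before the body of the expanded (MTP > 1) path, so the sketch below only illustrates one plausible way the per-request generation sequence lengths could be broadcast to one entry per flattened query before the scheduler metadata is built; the variable names and the expansion rule are assumptions, not code taken from dsa.py.

import torch

num_contexts, num_generations = 1, 2
max_draft_tokens = 3

# Per-request KV lengths for the whole batch (placeholder values); only the
# generation slice feeds the scheduler metadata computation.
kv_lens = torch.tensor([128, 17, 25], dtype=torch.int32)
gen_seq_lens = kv_lens[num_contexts:num_contexts + num_generations]

# One plausible expansion: repeat each request's KV length once per flattened
# query so the MTP > 1 path works with seq_len == 1 queries.
gen_seq_lens_expanded = gen_seq_lens.repeat_interleave(1 + max_draft_tokens)

assert gen_seq_lens_expanded.numel() == num_generations * (1 + max_draft_tokens)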
