diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 84a9d63b7a0..4db8930079e 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
     indexer_max_chunk_size: int
     # Topk for sparse MLA
     sparse_mla_topk: int
+    # max number of draft tokens
+    max_draft_tokens: int = 0

     def __init__(self, *args, **kwargs):
         self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -432,6 +434,64 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        self.create_expanded_buffers(capture_graph=capture_graph)
+
+    # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
+    def create_expanded_buffers(self, capture_graph=False):
+        self.kv_lens_expanded_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences * (1 + self.max_draft_tokens), ),
+            cache_name="kv_lens_expanded_cuda",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.kv_lens_expanded_host = torch.zeros_like(
+            self.kv_lens_expanded_cuda,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.block_table_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            [
+                self.max_num_sequences * (1 + self.max_draft_tokens),
+                self.kv_cache_manager.max_blocks_per_seq
+            ],
+            cache_name="block_table_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.host_block_table_expanded = torch.zeros_like(
+            self.block_table_expanded,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.scheduler_metadata_buffer_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.num_sms + 1, 2),
+            cache_name="scheduler_metadata_buffer_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+
+    # This override exists only to recreate the expanded buffers when max_draft_tokens changes.
+    # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
+    def update_spec_dec_param(
+        self,
+        is_spec_decoding_enabled,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
+        max_draft_tokens,
+        spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
+    ):
+        super().update_spec_dec_param(is_spec_decoding_enabled,
+                                      is_spec_dec_tree,
+                                      is_spec_dec_dynamic_tree,
+                                      max_draft_tokens, spec_decoding_tensor)
+        self.max_draft_tokens = max_draft_tokens
+        init_shape = self.kv_lens_expanded_host.shape[0]
+        if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            self.create_expanded_buffers(capture_graph=capture_graph)

     def prepare(self):
         super().prepare()
@@ -535,6 +595,41 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0

+        # Because fp8_paged_mqa_logits only supports seq_len == 1 or 2, it cannot support
+        # MTP > 1. To handle this, when MTP > 1, we flatten the q tensor and expand kv_lens
+        # and block_table so that fp8_paged_mqa_logits can still be used.
+        # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
+        if self.max_draft_tokens > 1:
+            # Expand kv_lens_cuda (only generation)
+            num_tokens = self.num_generations * (1 + self.max_draft_tokens)
+            gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+            gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
+                                               (1 + self.max_draft_tokens),
+                                               dim=0)
+            gen_kv_lens_expanded = gen_kv_lens_expanded.transpose(
+                0, 1).contiguous().flatten()
+            self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded)
+            self.kv_lens_expanded_cuda[:num_tokens].copy_(
+                self.kv_lens_expanded_host[:num_tokens], non_blocking=True)
+
+            # Expand indexer_k_cache_block_offsets (only generation)
+            if self.kv_cache_manager is not None:
+                block_ids = self.kv_cache_manager.get_batch_cache_indices(
+                    self.request_ids)
+                gen_block_ids = block_ids[self.num_contexts:]
+                if len(gen_block_ids) > 0:
+                    # Find max length and create padded tensor
+                    max_len = max(len(bid) for bid in gen_block_ids)
+                    gen_block_tensor = self.host_indexer_k_cache_block_offsets[
+                        self.num_contexts:self.num_seqs, :max_len]
+                    expanded_blocks = gen_block_tensor.repeat_interleave(
+                        1 + self.max_draft_tokens, dim=0)
+                    self.host_block_table_expanded[:num_tokens, :max_len].copy_(
+                        expanded_blocks, non_blocking=True)
+                    self.block_table_expanded[:num_tokens].copy_(
+                        self.host_block_table_expanded[:num_tokens],
+                        non_blocking=True)
+
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)

@@ -799,12 +894,22 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         if num_generations > 0:
             # Prepare schedule metadata for fp8_paged_mqa_logits
             # This is a preprocessing step that computes scheduling information for the kernel
-            gen_seq_lens = metadata.kv_lens_cuda_runtime[
-                num_contexts:num_contexts + num_generations]
-            scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
-                gen_seq_lens, tokens_per_block, metadata.num_sms)
-            metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
-                                                     non_blocking=True)
+            if metadata.max_draft_tokens <= 1:
+                gen_seq_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
+                    gen_seq_lens, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer.copy_(
+                    scheduler_metadata_buffer, non_blocking=True)
+            else:
+                # Expand schedule metadata buffer (only generation)
+                num_tokens = metadata.num_generations * (
+                    1 + metadata.max_draft_tokens)
+                kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens]
+                scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
+                    kv_lens_expanded, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer_expanded.copy_(
+                    scheduler_metadata_buffer_expanded, non_blocking=True)

         # Compute slot_mapping for all requests (both context and generation)
         # This maps each token to its flat cache position for vectorized KV cache updates
@@ -1053,9 +1158,24 @@ def sparse_attn_indexer(
             # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
             q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
                              ...]
-            q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-            batch_size = q_decode.shape[0]
-            next_n = q_decode.shape[1]
+            batch_size = num_generations
+            next_n = num_gen_tokens // num_generations
+            # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor
+            # and expand the corresponding metadata.
+            if next_n <= 2:
+                q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
+                context_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                block_table = metadata.indexer_k_cache_block_offsets[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
+            else:
+                q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+                num_tokens = num_generations * (1 + metadata.max_draft_tokens)
+                context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
+                block_table = metadata.block_table_expanded[:num_tokens]
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded
+            assert num_gen_tokens == batch_size * next_n
             weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
                                      num_gen_tokens, ...]
@@ -1064,18 +1184,11 @@
             # [num_blocks, tokens_per_block, 1, head_dim + scale_size]
             k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
                 self.layer_idx)
-            logits_decode = fp8_paged_mqa_logits(
-                q_decode,
-                k_cache,
-                weights_decode,
-                metadata.kv_lens_cuda_runtime[
-                    num_contexts:num_contexts +
-                    num_generations],  # context_lens prepared in prepare()
-                metadata.indexer_k_cache_block_offsets[
-                    num_contexts:num_contexts +
-                    num_generations],  # Only pass generation request block tables
-                metadata.scheduler_metadata_buffer,
-                max_seq_len)
+            logits_decode = fp8_paged_mqa_logits(q_decode, k_cache,
+                                                 weights_decode, context_lens,
+                                                 block_table,
+                                                 scheduler_metadata_buffer,
+                                                 max_seq_len)

             if use_custom_topk:
                 # Kernel expects kv_lens (total cache length), not seq_lens (new tokens)
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index fe95c4cc093..d41560e3cb1 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2380,7 +2380,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             (8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT"),
             (8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT"),
             (8, 1, 8, 0, True, True, True, True, 24, "_DEFAULT"),
-            (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
+            (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
         ],
         ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
     def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
@@ -2448,7 +2448,7 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
             (8, 1, 8, 0, False, True, True, True, 24, "CUTLASS"),
             (8, 1, 8, 1, False, True, True, True, 24, "CUTLASS"),
             (8, 1, 8, 0, True, True, True, True, 24, "CUTLASS"),
-            (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
+            (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
         ],
         ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
index 6cfe276a105..75da9cebba2 100644
--- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
+++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
@@ -381,7 +381,8 @@ def _create_mock_metadata(request_ids,
                           cache_manager,
                           num_ctx_tokens,
                           num_tokens,
-                          indexer_max_chunk_size=8194):
+                          indexer_max_chunk_size=8194,
+                          max_draft_tokens=0):
     """Helper to create mock metadata for testing."""

     class MockKVCacheParams:
@@ -396,6 +397,7 @@ def __init__(self):
             self.request_ids = request_ids
             self.num_contexts = num_contexts
             self.num_generations = num_generations
+            self.max_draft_tokens = max_draft_tokens
             # Keep seq_lens on CPU for split_prefill_chunks and other CPU operations
             # CUDA kernels will convert to CUDA as needed
             self.seq_lens = seq_lens.cpu() if seq_lens.is_cuda else seq_lens
@@ -826,6 +828,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
@@ -851,6 +854,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=batch_size * num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen)
@@ -1418,6 +1422,7 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
     indexer._update_k_cache(k_context_fp8, k_context_scale, metadata_context)
@@ -1450,16 +1455,24 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen_write)
     indexer._update_k_cache(k_fp8, k_scale, metadata_gen_write)

     # Test with custom CUDA kernel
-    metadata_custom = _create_mock_metadata(request_ids, batch_size, 0,
-                                            batch_size, seq_lens.clone(),
+    metadata_custom = _create_mock_metadata(request_ids,
+                                            batch_size,
+                                            0,
+                                            batch_size,
+                                            seq_lens.clone(),
                                             final_lens.clone(),
-                                            num_cached_tokens, cache_manager, 0,
-                                            num_gen_tokens, max_model_len)
+                                            num_cached_tokens,
+                                            cache_manager,
+                                            0,
+                                            num_gen_tokens,
+                                            max_model_len,
+                                            max_draft_tokens=next_n - 1)
     Indexer.prepare(metadata_custom)
     indexer._update_k_cache(k_fp8, k_scale, metadata_custom)
@@ -1476,11 +1489,18 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         pytest.skip(f"Custom topk not available: {e}")

     # Test with PyTorch fallback
-    metadata_fallback = _create_mock_metadata(request_ids, batch_size, 0,
-                                              batch_size, seq_lens.clone(),
+    metadata_fallback = _create_mock_metadata(request_ids,
+                                              batch_size,
+                                              0,
+                                              batch_size,
+                                              seq_lens.clone(),
                                               final_lens.clone(),
-                                              num_cached_tokens, cache_manager,
-                                              0, num_gen_tokens, max_model_len)
+                                              num_cached_tokens,
+                                              cache_manager,
+                                              0,
+                                              num_gen_tokens,
+                                              max_model_len,
+                                              max_draft_tokens=next_n - 1)
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
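
Reviewer note: the core workaround in the dsa.py changes is that, when MTP > 1, the generation queries are flattened to single-token rows (q_decode.view(-1, 1, ...)) while each request's kv_lens and block-table row is duplicated once per query token, so fp8_paged_mqa_logits only ever sees next_n == 1. Below is a minimal, standalone sketch of that expansion. The helper name and shapes are illustrative assumptions, not part of the patch; the patch builds the same layout with stack/transpose/flatten and repeat_interleave on its pinned host buffers.

import torch


def expand_gen_metadata(kv_lens: torch.Tensor, block_table: torch.Tensor,
                        max_draft_tokens: int):
    # Hypothetical helper (not in the patch): replicate each generation
    # request's cache length and block-table row once per query token so
    # a kernel limited to seq_len == 1 can serve MTP > 1 batches.
    steps = 1 + max_draft_tokens
    kv_lens_expanded = kv_lens.repeat_interleave(steps)  # [num_gen * steps]
    block_table_expanded = block_table.repeat_interleave(steps, dim=0)
    return kv_lens_expanded, block_table_expanded


if __name__ == "__main__":
    kv_lens = torch.tensor([17, 42], dtype=torch.int32)  # 2 generation requests
    block_table = torch.tensor([[3, 7, 0], [5, 9, 11]], dtype=torch.int32)
    lens_x, blocks_x = expand_gen_metadata(kv_lens, block_table, max_draft_tokens=3)
    # lens_x   -> tensor([17, 17, 17, 17, 42, 42, 42, 42], dtype=torch.int32)
    # blocks_x -> shape [8, 3]; q is viewed as (-1, 1, ...) so each row pairs
    # with exactly one single-token "request" from the kernel's point of view.
    print(lens_x, blocks_x.shape)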