
Commit a7033a9

[TRTLLM-9001][feat] add TP support for DeepSeek-V3.2 (#8943)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>

Parent: 78fac1f

File tree: 7 files changed (+27, -16 lines)

examples/llm-api/quickstart_advanced.py

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,7 @@ def add_llm_args(parser):
     parser.add_argument('--disable_kv_cache_reuse',
                         default=False,
                         action='store_true')
+    parser.add_argument("--tokens_per_block", type=int, default=32)
 
     # Runtime
     parser.add_argument('--disable_overlap_scheduler',
@@ -180,6 +181,7 @@ def setup_llm(args, **kwargs):
         enable_block_reuse=not args.disable_kv_cache_reuse,
         free_gpu_memory_fraction=args.kv_cache_fraction,
         dtype=args.kv_cache_dtype,
+        tokens_per_block=args.tokens_per_block,
     )
 
     spec_decode_algo = args.spec_decode_algo.upper(
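For context, the new --tokens_per_block flag is only threaded into the KV-cache configuration built in setup_llm. Below is a minimal sketch of the resulting object, assuming the surrounding code constructs a tensorrt_llm.llmapi.KvCacheConfig (the diff shows only the keyword arguments) and using placeholder values for the other flags:

from tensorrt_llm.llmapi import KvCacheConfig

# Sketch only: values other than tokens_per_block stand in for the existing
# quickstart flags (disable_kv_cache_reuse, kv_cache_fraction, kv_cache_dtype).
kv_cache_config = KvCacheConfig(
    enable_block_reuse=True,        # not args.disable_kv_cache_reuse
    free_gpu_memory_fraction=0.8,   # args.kv_cache_fraction
    dtype="auto",                   # args.kv_cache_dtype
    tokens_per_block=32,            # new flag, default 32
)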

tensorrt_llm/_torch/modules/attention.py

Lines changed: 14 additions & 14 deletions
@@ -1969,10 +1969,10 @@ def forward_sparse_mla_kvcache_bf16(
             q, latent_cache, attn_metadata, is_generation=is_generation)
 
         num_tokens = q.shape[0]
-        q_nope, q_rope = q.view(-1, self.num_heads, self.qk_head_dim).split(
+        q_nope, q_rope = q.view(-1, self.num_heads_tp, self.qk_head_dim).split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         q_nope_out = torch.empty(
-            [num_tokens, self.num_heads, (self.kv_lora_rank)],
+            [num_tokens, self.num_heads_tp, (self.kv_lora_rank)],
             dtype=q.dtype,
             device=q.device,
         )
@@ -2011,23 +2011,23 @@ def forward_sparse_mla_kvcache_bf16(
         # FlashMLA sparse kernel (bf16) requires num_heads=128 on sm100 or multiple of 64 on sm90
         if sm_version >= 100:
             padding = 128
-            assert self.num_heads <= padding, (
+            assert self.num_heads_tp <= padding, (
                 f"SM100 FlashMLA sparse kernel requires exactly {padding} heads, "
-                f"got {self.num_heads}. Padding from values > {padding} is not supported."
+                f"got {self.num_heads_tp}. Padding from values > {padding} is not supported."
             )
         else:  # SM90
-            padding = ((self.num_heads + 63) // 64) * 64  # multiple of 64
+            padding = ((self.num_heads_tp + 63) // 64) * 64  # multiple of 64
 
-        if self.num_heads != padding:
+        if self.num_heads_tp != padding:
             logger.warning_once(
-                f"Padding num_heads from {self.num_heads} to {padding} "
+                f"Padding num_heads from {self.num_heads_tp} to {padding} "
                 f"due to FlashMLA sparse attention kernel requirement",
                 key="sparse_mla_padding_warning")
 
         # Create padded tensor with zeros for extra heads
         q_padded = q_concat.new_empty(
             (num_tokens, padding, q_concat.shape[2]))
-        q_padded[:, :self.num_heads, :] = q_concat
+        q_padded[:, :self.num_heads_tp, :] = q_concat
         q_concat = q_padded
 
         # Convert indices and return all-layer KV pool
@@ -2049,17 +2049,17 @@ def forward_sparse_mla_kvcache_bf16(
             "flash_mla_sparse_fwd not available. Please ensure FlashMLA module is built."
         )
 
-        # [seq, num_heads, kv_lora_rank]
-        attn_out_latent = attn_out_latent[:, :self.
-                                          num_heads, :]  # account for padding
+        # [seq, num_heads, kv_lora_rank], account for padding
+        attn_out_latent = attn_out_latent[:, :self.num_heads_tp, :]
         # TODO: seems we need .contiguous() here when padding enabled before pass to bmm?
         attn_out_latent = attn_out_latent.view(
-            [-1, self.num_heads, self.kv_lora_rank])
+            [-1, self.num_heads_tp, self.kv_lora_rank])
 
         assert (attn_out_latent.shape[0] == q.shape[0]
-                and attn_out_latent.shape[1] == self.num_heads)
+                and attn_out_latent.shape[1] == self.num_heads_tp)
 
-        attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])
+        attn_output = output.view(
+            [num_tokens, self.num_heads_tp, self.v_head_dim])
 
         if self.v_b_proj.dtype == torch.bfloat16:
             # [num_heads, seq, kv_lora_rank] x [num_heads, kv_lora_rank, v_head_dim]
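The switch from num_heads to num_heads_tp reflects that under tensor parallelism each rank only holds its share of the query heads, and the padding rule above then rounds that per-rank count back up to what the FlashMLA sparse kernel accepts. A standalone sketch of the arithmetic, assuming the usual even split num_heads_tp = num_heads // tp_size (the diff itself only renames the attribute); the helper name is illustrative:

def flashmla_sparse_head_padding(num_heads: int, tp_size: int,
                                 sm_version: int) -> int:
    """Sketch of the head padding applied on the bf16 sparse-MLA path.

    The FlashMLA sparse kernel wants 128 heads on SM100 and a multiple of
    64 on SM90, so a smaller per-rank head count is zero-padded up.
    """
    num_heads_tp = num_heads // tp_size  # per-rank heads under TP (assumed even split)
    if sm_version >= 100:
        padding = 128
        assert num_heads_tp <= padding, "padding beyond 128 heads is not supported"
    else:  # SM90
        padding = ((num_heads_tp + 63) // 64) * 64
    return padding

# e.g. 128 query heads split over TP=8 leaves 16 heads per rank,
# padded back to 128 on SM100 and to 64 on SM90.
assert flashmla_sparse_head_padding(128, 8, 100) == 128
assert flashmla_sparse_head_padding(128, 8, 90) == 64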

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 4 additions & 2 deletions
@@ -2380,8 +2380,9 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             (8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT"),
             (8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT"),
             (8, 1, 8, 0, True, True, True, True, 24, "_DEFAULT"),
+            (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
         ],
-        ids=["baseline", "baseline_mtp1", "baseline_fp8kv"])
+        ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
     def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                             attention_dp, cuda_graph, overlap_scheduler,
                             max_batch_size, moe_backend):
@@ -2447,8 +2448,9 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
             (8, 1, 8, 0, False, True, True, True, 24, "CUTLASS"),
             (8, 1, 8, 1, False, True, True, True, 24, "CUTLASS"),
             (8, 1, 8, 0, True, True, True, True, 24, "CUTLASS"),
+            (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
         ],
-        ids=["baseline", "baseline_mtp1", "baseline_fp8kv"])
+        ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                               attention_dp, cuda_graph, overlap_scheduler,
                               max_batch_size, moe_backend):
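The new "latency" id is positional: it names the appended tuple (tp_size=8, ep_size=8, mtp_nextn=1, attention_dp=False, max_batch_size=1, moe_backend="TRTLLM"). A reduced, self-contained sketch of how such a parametrization reads, with the test body replaced by a trivial check (the real tests run the accuracy harness):

import pytest

@pytest.mark.parametrize(
    "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,"
    "overlap_scheduler,max_batch_size,moe_backend",
    [(8, 1, 8, 1, False, False, True, True, 1, "TRTLLM")],
    ids=["latency"])
def test_latency_parametrization(tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                                 attention_dp, cuda_graph, overlap_scheduler,
                                 max_batch_size, moe_backend):
    # The "latency" case: single-batch, MTP draft length 1,
    # attention DP off, TRTLLM MoE backend.
    assert max_batch_size == 1 and moe_backend == "TRTLLM" and mtp_nextn == 1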

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 2 additions & 0 deletions
@@ -496,9 +496,11 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_pr
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 2 additions & 0 deletions
@@ -51,9 +51,11 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_c
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions
@@ -125,7 +125,9 @@ l0_dgx_b200:
   tests:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (180)
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] TIMEOUT (180)
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] TIMEOUT (180)
 - condition:
     ranges:
       system_gpu_count:

tests/integration/test_lists/test-db/l0_dgx_h200.yml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ l0_dgx_h200:
   # - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] # OOM
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency]
   - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True]
