3 changes: 0 additions & 3 deletions .pre-commit-config.yaml
@@ -83,7 +83,6 @@ common-files: &common_files |
examples/infinitebench/compute_scores.py |
examples/infinitebench/construct_synthetic_dataset.py |
examples/infinitebench/eval_utils.py |
examples/layer_wise_benchmarks/run_single.py |
examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py |
examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py |
examples/llm-api/_tensorrt_engine/llm_inference_customize.py |
@@ -811,7 +810,6 @@ common-files: &common_files |
tensorrt_llm/serve/tool_parser/utils.py |
tensorrt_llm/tools/__init__.py |
tensorrt_llm/tools/importlib_utils.py |
tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py |
tensorrt_llm/tools/multimodal_builder.py |
tensorrt_llm/tools/onnx_utils.py |
tensorrt_llm/tools/plugin_gen/__init__.py |
@@ -1188,7 +1186,6 @@ common-files: &common_files |
tests/unittest/tools/plugin_gen/test_core.py |
tests/unittest/tools/plugin_gen/test_plugin_gen.py |
tests/unittest/tools/plugin_gen/test_shape_infer.py |
tests/unittest/tools/test_layer_wise_benchmarks.py |
tests/unittest/tools/test_prepare_dataset.py |
tests/unittest/tools/test_test_to_stage_mapping.py |
tests/unittest/trt/__init__.py |
16 changes: 13 additions & 3 deletions examples/layer_wise_benchmarks/README.md
@@ -15,7 +15,7 @@ pip install -e ../..
**Step 3:** In the container, run benchmarks and generate profiles:

```bash
# Run DeepSeek-R1
# Run DeepSeek-R1 NVFP4
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml

@@ -24,7 +24,7 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSee
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM

# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769

# Run with attention TP
@@ -48,6 +48,10 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-back
# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp

# Run Qwen3-Next (balanced routing is not implemented)
NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified

# Run with DeepEP A2A
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
@@ -76,7 +80,7 @@ It uses the image recorded in `../../jenkins/current_image_tags.properties`. The
**Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:

```bash
# Run DeepSeek-R1 with wide ep: uses MNNVL A2A if applicable
# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP

# Run with attention TP and TRTLLMGen
@@ -93,3 +97,9 @@ SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config
## Parse profiles

Coming soon.

## Troubleshooting

1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

   The default MoE backend `CUTLASS` does not support FP8 weights. Choose the same MoE backend as your end-to-end config, typically by adding the `--moe-backend DEEPGEMM`, `--moe-backend TRTLLM`, or `--moe-backend WIDEEP` option, as in the sketch below.
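
   A minimal illustrative command (not part of the original quick-start list; it simply combines the `config_gen.yaml` invocation shown above with a non-CUTLASS backend):

   ```bash
   # Hypothetical example: same launch as the quick-start section, but with the
   # TRTLLM MoE backend so the FP8 block-scale GEMM path is not required.
   NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend TRTLLM
   ```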
1 change: 0 additions & 1 deletion examples/layer_wise_benchmarks/config_ctx.yaml
@@ -9,7 +9,6 @@ max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 20480
moe_backend: CUTLASS
use_cuda_graph: false

1 change: 0 additions & 1 deletion examples/layer_wise_benchmarks/config_gen.yaml
@@ -9,7 +9,6 @@ max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 4096 # MTP3 as max
moe_backend: CUTLASS
use_cuda_graph: true

116 changes: 57 additions & 59 deletions examples/layer_wise_benchmarks/run_single.py
@@ -8,8 +8,7 @@
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
BalanceMethod, DeepSeekV3Runner)
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls


def comma_separated_ints(s):
@@ -23,30 +22,25 @@ def comma_separated_ints(s):
parser.add_argument(
"--layer-indices",
type=comma_separated_ints,
help="Comma separated indices of layers, should be a contiguous range")
help="Comma separated indices of layers, should be a contiguous range",
)
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
action="store_true",
dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
action="store_false",
dest="enable_attention_dp")
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
action="store_true",
dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
action="store_false",
dest="use_cuda_graph")
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
@@ -59,8 +53,12 @@ def comma_separated_ints(s):
config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
if v is None:
if v is None and k in config:
setattr(args, k, config[k])
if args.max_batch_size is None:
args.max_batch_size = args.batch_size
if args.max_num_tokens is None:
args.max_num_tokens = args.max_batch_size * args.seq_len_q
print(args)

# MPI args
@@ -70,43 +68,49 @@ def comma_separated_ints(s):
torch.cuda.set_device(local_rank)

# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
args.model,
mapping,
tokens_per_block=args.tokens_per_block,
max_batch_size=max_batch_size,
max_batch_size=args.max_batch_size,
max_seq_len=args.max_seq_len,
layer_indices=args.layer_indices)
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)
layer_indices=args.layer_indices,
)
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)

# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()

# Create Runner
runner = DeepSeekV3Runner(args.model,
mapping,
moe_backend=args.moe_backend,
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
use_cuda_graph=args.use_cuda_graph)
runner = Runner(
args.model,
mapping,
moe_backend=args.moe_backend,
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
moe_max_num_tokens=args.moe_max_num_tokens,
use_cuda_graph=args.use_cuda_graph,
)

# Warm up
assert args.batch_size <= max_batch_size
assert args.batch_size <= args.max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
batch_size=args.batch_size,
seq_len_q=args.seq_len_q,
seq_len_kv_cache=args.seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace)
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
balance_ratio=args.balance_ratio)
run_pack = runner.create_run_pack(
args.run_type,
batch_size=args.batch_size,
seq_len_q=args.seq_len_q,
seq_len_kv_cache=args.seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace,
)
runner.replace_routing_method(
balance_method=BalanceMethod[args.balance_method], balance_ratio=args.balance_ratio
)
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
run_pack()
@@ -120,21 +124,15 @@ def comma_separated_ints(s):
if args.use_cuda_graph:
with with_multi_stream(True):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g,
stream=capture_stream,
capture_error_mode="global"):
with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
run_pack()

warmup_times = 20
run_times = 100
events = [
torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times + run_times + 1)
]
events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)]
for i in range(warmup_times + run_times):
events[i].record()
with nvtx.annotate(
f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
with nvtx.annotate(f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
if args.use_cuda_graph:
g.replay()
else:
@@ -144,16 +142,16 @@ def comma_separated_ints(s):

# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [
start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
time_list = time_list[warmup_times:]
print(f"[RANK {rank}]"
f" min {np.min(time_list) * 1000:.1f}"
f" max {np.max(time_list) * 1000:.1f}"
f" mean {np.mean(time_list) * 1000:.1f}"
f" median {np.median(time_list) * 1000:.1f}"
f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
f" (us)")
print(
f"[RANK {rank}]"
f" min {np.min(time_list) * 1000:.1f}"
f" max {np.max(time_list) * 1000:.1f}"
f" mean {np.mean(time_list) * 1000:.1f}"
f" median {np.median(time_list) * 1000:.1f}"
f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
f" (us)"
)

torch.cuda.cudart().cudaProfilerStop()
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -123,7 +123,6 @@ exclude = [
"examples/infinitebench/compute_scores.py",
"examples/infinitebench/construct_synthetic_dataset.py",
"examples/infinitebench/eval_utils.py",
"examples/layer_wise_benchmarks/run_single.py",
"examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py",
"examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py",
"examples/llm-api/_tensorrt_engine/llm_inference_customize.py",
@@ -851,7 +850,6 @@ exclude = [
"tensorrt_llm/serve/tool_parser/utils.py",
"tensorrt_llm/tools/__init__.py",
"tensorrt_llm/tools/importlib_utils.py",
"tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py",
"tensorrt_llm/tools/multimodal_builder.py",
"tensorrt_llm/tools/onnx_utils.py",
"tensorrt_llm/tools/plugin_gen/__init__.py",
@@ -1228,7 +1226,6 @@ exclude = [
"tests/unittest/tools/plugin_gen/test_core.py",
"tests/unittest/tools/plugin_gen/test_plugin_gen.py",
"tests/unittest/tools/plugin_gen/test_shape_infer.py",
"tests/unittest/tools/test_layer_wise_benchmarks.py",
"tests/unittest/tools/test_prepare_dataset.py",
"tests/unittest/tools/test_test_to_stage_mapping.py",
"tests/unittest/trt/__init__.py",
7 changes: 7 additions & 0 deletions tensorrt_llm/tools/layer_wise_benchmarks/__init__.py
@@ -0,0 +1,7 @@
from .runner_factory import get_runner_cls
from .runner_interface import BalanceMethod

__all__ = [
"BalanceMethod",
"get_runner_cls",
]
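
For context, a minimal sketch of how these new exports are consumed, following the calls visible in the `examples/layer_wise_benchmarks/run_single.py` diff above; the model name here is an illustrative placeholder, and the runner constructor takes additional keyword arguments not shown:

```python
# Illustrative sketch only, mirroring run_single.py; not code from this PR.
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls

model = "deepseek-ai/DeepSeek-V3.2-Exp"                    # placeholder model name/path
Runner = get_runner_cls(model)                             # select the runner class for this model
mapping = Runner.create_mapping(enable_attention_dp=True)  # build the parallel mapping
balance_method = BalanceMethod["NotModified"]              # same enum lookup run_single.py performs
```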