From 7816b85cc7007ecf29354c47a29ab9173d406127 Mon Sep 17 00:00:00 2001
From: Tailing Yuan
Date: Wed, 5 Nov 2025 11:12:10 +0000
Subject: [PATCH 1/9] Refine arguments

Signed-off-by: Tailing Yuan
---
 examples/layer_wise_benchmarks/README.md       | 12 +++++++++---
 examples/layer_wise_benchmarks/config_ctx.yaml |  1 -
 examples/layer_wise_benchmarks/config_gen.yaml |  1 -
 examples/layer_wise_benchmarks/run_single.py   | 14 ++++++++++----
 .../layer_wise_benchmarks/deepseekv3_runner.py |  5 +++--
 5 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/examples/layer_wise_benchmarks/README.md b/examples/layer_wise_benchmarks/README.md
index b0ca53e9aae..bf927ed278a 100644
--- a/examples/layer_wise_benchmarks/README.md
+++ b/examples/layer_wise_benchmarks/README.md
@@ -15,7 +15,7 @@ pip install -e ../..
 **Step 3:** In the container, run benchmarks and generate profiles:
 
 ```bash
-# Run DeepSeek-R1
+# Run DeepSeek-R1 NVFP4
 NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
 NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml
 
@@ -24,7 +24,7 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
 NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
 
 # Run DeepSeek-V3.2-Exp with 32k context length
-NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
+NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
 NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
 
 # Run with attention TP
@@ -76,7 +76,7 @@ It uses the image recorded in `../../jenkins/current_image_tags.properties`. The
 **Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:
 
 ```bash
-# Run DeepSeek-R1 with wide ep: uses MNNVL A2A if applicable
+# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
 SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
 
 # Run with attention TP and TRTLLMGen
@@ -93,3 +93,9 @@ SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config
 ## Parse profiles
 
 Coming soon.
+
+## Troubleshooting
+
+1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.
+
+   The default MoE backend "CUTLASS" does not support FP8 weights. Please choose the same MoE backend as your end-to-end config, typically by adding the `--moe-backend DEEPGEMM`, `--moe-backend TRTLLM`, or `--moe-backend WIDEEP` option.
diff --git a/examples/layer_wise_benchmarks/config_ctx.yaml b/examples/layer_wise_benchmarks/config_ctx.yaml index 13a637e1624..07f5a8b7bd2 100644 --- a/examples/layer_wise_benchmarks/config_ctx.yaml +++ b/examples/layer_wise_benchmarks/config_ctx.yaml @@ -9,7 +9,6 @@ max_seq_len: 9220 # 8192 + 1024 + 4 enable_attention_dp: true # Model init args -max_num_tokens: 20480 moe_backend: CUTLASS use_cuda_graph: false diff --git a/examples/layer_wise_benchmarks/config_gen.yaml b/examples/layer_wise_benchmarks/config_gen.yaml index 9ad86f8e594..494d4b0ff40 100644 --- a/examples/layer_wise_benchmarks/config_gen.yaml +++ b/examples/layer_wise_benchmarks/config_gen.yaml @@ -9,7 +9,6 @@ max_seq_len: 9220 # 8192 + 1024 + 4 enable_attention_dp: true # Model init args -max_num_tokens: 4096 # MTP3 as max moe_backend: CUTLASS use_cuda_graph: true diff --git a/examples/layer_wise_benchmarks/run_single.py b/examples/layer_wise_benchmarks/run_single.py index 79d4bbe5019..b840b4366f9 100644 --- a/examples/layer_wise_benchmarks/run_single.py +++ b/examples/layer_wise_benchmarks/run_single.py @@ -27,6 +27,7 @@ def comma_separated_ints(s): parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"]) parser.add_argument("--scaled-from", type=int) # KV cache related args +parser.add_argument("--max-batch-size", type=int) parser.add_argument("--tokens-per-block", type=int) parser.add_argument("--max-seq-len", type=int) group = parser.add_mutually_exclusive_group(required=False) @@ -40,6 +41,7 @@ def comma_separated_ints(s): # Model init args parser.add_argument("--max-num-tokens", type=int) parser.add_argument("--moe-backend", type=str) +parser.add_argument("--moe-max-num-tokens", type=int) group = parser.add_mutually_exclusive_group(required=False) group.add_argument("--use-cuda-graph", action="store_true", @@ -59,8 +61,12 @@ def comma_separated_ints(s): config = yaml.safe_load(f) del args.config_path for k, v in vars(args).items(): - if v is None: + if v is None and k in config: setattr(args, k, config[k]) +if args.max_batch_size is None: + args.max_batch_size = args.batch_size +if args.max_num_tokens is None: + args.max_num_tokens = args.max_batch_size * args.seq_len_q print(args) # MPI args @@ -72,12 +78,11 @@ def comma_separated_ints(s): # Create KV cache manager mapping = DeepSeekV3Runner.create_mapping( enable_attention_dp=args.enable_attention_dp) -max_batch_size = 2048 kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager( args.model, mapping, tokens_per_block=args.tokens_per_block, - max_batch_size=max_batch_size, + max_batch_size=args.max_batch_size, max_seq_len=args.max_seq_len, layer_indices=args.layer_indices) attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8) @@ -94,10 +99,11 @@ def comma_separated_ints(s): scaled_from=args.scaled_from, max_seq_len=args.max_seq_len, max_num_tokens=args.max_num_tokens, + moe_max_num_tokens=args.moe_max_num_tokens, use_cuda_graph=args.use_cuda_graph) # Warm up -assert args.batch_size <= max_batch_size +assert args.batch_size <= args.max_batch_size assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len run_pack = runner.create_run_pack(args.run_type, batch_size=args.batch_size, diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py index 2bc9ef8a9c3..1627c224739 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py @@ -142,7 +142,8 @@ class DeepSeekV3Runner: 
def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, moe_backend: str, layer_indices: List[int], scaled_from: Optional[int], max_seq_len: int, - max_num_tokens: int, use_cuda_graph: bool): + max_num_tokens: int, moe_max_num_tokens: int, + use_cuda_graph: bool): # Temporally replace the gate class gate_cls_orig = tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate @@ -158,7 +159,7 @@ def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, sparse_attention_config=None, # To be loaded from config max_num_tokens=max_num_tokens, max_seq_len=max_seq_len, - moe_max_num_tokens=None, + moe_max_num_tokens=moe_max_num_tokens, moe_load_balancer=None, lora_config=None, allreduce_strategy=AllReduceStrategy.AUTO, From d6899e4aed0b7414d12039474f1dd3edfbd08c48 Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 03:31:31 +0000 Subject: [PATCH 2/9] Add Qwen3-Next layer-wise benchmarks Signed-off-by: Tailing Yuan --- examples/layer_wise_benchmarks/README.md | 4 + examples/layer_wise_benchmarks/run_single.py | 29 +- .../deepseekv3_runner.py | 304 +++--------------- .../qwen3_next_runner.py | 90 ++++++ .../layer_wise_benchmarks/runner_base.py | 49 +++ .../layer_wise_benchmarks/runner_factory.py | 13 + .../layer_wise_benchmarks/runner_utils.py | 303 +++++++++++++++++ 7 files changed, 521 insertions(+), 271 deletions(-) create mode 100644 tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py create mode 100644 tensorrt_llm/tools/layer_wise_benchmarks/runner_base.py create mode 100644 tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py create mode 100644 tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py diff --git a/examples/layer_wise_benchmarks/README.md b/examples/layer_wise_benchmarks/README.md index bf927ed278a..79cbb745022 100644 --- a/examples/layer_wise_benchmarks/README.md +++ b/examples/layer_wise_benchmarks/README.md @@ -48,6 +48,10 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-back # Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp +# Run Qwen3-Next (balanced routing is not implemented) +NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified +NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified + # Run with DeepEP A2A NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP diff --git a/examples/layer_wise_benchmarks/run_single.py b/examples/layer_wise_benchmarks/run_single.py index b840b4366f9..da700907bb3 100644 --- a/examples/layer_wise_benchmarks/run_single.py +++ b/examples/layer_wise_benchmarks/run_single.py @@ -8,8 +8,9 @@ from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size -from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import ( - BalanceMethod, DeepSeekV3Runner) +from 
tensorrt_llm.tools.layer_wise_benchmarks.runner_base import BalanceMethod +from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import \ + get_runner_cls def comma_separated_ints(s): @@ -76,9 +77,9 @@ def comma_separated_ints(s): torch.cuda.set_device(local_rank) # Create KV cache manager -mapping = DeepSeekV3Runner.create_mapping( - enable_attention_dp=args.enable_attention_dp) -kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager( +Runner = get_runner_cls(args.model) +mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp) +kv_cache_manager = Runner.create_kv_cache_manager( args.model, mapping, tokens_per_block=args.tokens_per_block, @@ -92,15 +93,15 @@ def comma_separated_ints(s): capture_stream = torch.cuda.Stream() # Create Runner -runner = DeepSeekV3Runner(args.model, - mapping, - moe_backend=args.moe_backend, - layer_indices=args.layer_indices, - scaled_from=args.scaled_from, - max_seq_len=args.max_seq_len, - max_num_tokens=args.max_num_tokens, - moe_max_num_tokens=args.moe_max_num_tokens, - use_cuda_graph=args.use_cuda_graph) +runner = Runner(args.model, + mapping, + moe_backend=args.moe_backend, + layer_indices=args.layer_indices, + scaled_from=args.scaled_from, + max_seq_len=args.max_seq_len, + max_num_tokens=args.max_num_tokens, + moe_max_num_tokens=args.moe_max_num_tokens, + use_cuda_graph=args.use_cuda_graph) # Warm up assert args.batch_size <= args.max_batch_size diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py index 1627c224739..4469096eb2d 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py @@ -1,45 +1,20 @@ import functools -import os -import weakref -from enum import IntEnum from typing import List, Optional import torch import tensorrt_llm._torch.models.modeling_deepseekv3 -from tensorrt_llm._torch.attention_backend.utils import get_attention_backend -from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_deepseekv3 import ( DeepseekV3DecoderLayer, DeepseekV3Gate) -from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE -from tensorrt_llm._torch.modules.linear import Linear, WeightMode from tensorrt_llm._torch.modules.rms_norm import RMSNorm -from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls -from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager -from tensorrt_llm._torch.utils import (AuxStreamType, get_model_extra_attrs, - model_extra_attrs) -from tensorrt_llm._utils import (local_mpi_size, mpi_rank, mpi_world_size, - torch_dtype_to_binding) -from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm._torch.utils import AuxStreamType +from tensorrt_llm._utils import mpi_rank, mpi_world_size from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.modeling_utils import QuantConfig - -class BalanceMethod(IntEnum): - NotModified = 1 - Balanced = 2 - ImbalancedRanks = 3 - ImbalancedExperts = 4 - - -def ceil_div(a, b): - return (a + b - 1) // b - - -def round_up(a, b): - return ceil_div(a, b) * b +from .runner_base import BalanceMethod, RunnerBase +from .runner_utils import RunnerMixin, ceil_div class RoutingMethod(DeepseekV3Gate): @@ -137,7 +112,11 @@ def get_balanced_rank_imbalanced_expert_selection(num_tokens, top_k, 
world_size, rank) -class DeepSeekV3Runner: +class DeepSeekV3Runner(RunnerMixin, RunnerBase): + + @staticmethod + def has_mamba_metadata() -> bool: + return False def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, moe_backend: str, layer_indices: List[int], @@ -170,176 +149,48 @@ def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, use_low_precision_moe_combine=False, skip_create_weights_in_init=True, ) - pretrained_config = self.model_config.pretrained_config - if scaled_from is not None: - # To run the problem size of $B$ GPUs on $A$ GPUs, we need: - # (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change. - # (2) MoE: If EP, reduce the number of experts; If TP, reduce head size. - # Maintain the result of AllToAll method selection because it is affected by EP size. - if not mapping.enable_attention_dp: - if hasattr(pretrained_config, "index_n_heads"): - raise NotImplementedError( - "Not support Indexer TP for weak scaling") - pretrained_config.num_attention_heads = pretrained_config.num_attention_heads // scaled_from * mapping.tp_size - pretrained_config.num_key_value_heads = pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size - if mapping.moe_ep_size != mapping.world_size: - raise NotImplementedError("Not support MoE TP for weak scaling") - pretrained_config.n_routed_experts = pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size - select_alltoall_method_type_orig = WideEPMoE.select_alltoall_method_type - - def select_alltoall_method_type(cls: type, mapping: Mapping, - top_k: int, *args, **kwargs): - # Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k` - # by replacing `top_k` with `fake_top_k` - if scaled_from <= top_k: - fake_top_k = mapping.moe_ep_size + 1 - else: - fake_top_k = mapping.moe_ep_size - 1 - assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from - <= top_k) - return select_alltoall_method_type_orig(mapping, fake_top_k, - *args, **kwargs) - - WideEPMoE.select_alltoall_method_type = select_alltoall_method_type - - aux_stream_list = [torch.cuda.Stream() for _ in range(2)] - aux_stream_dict = { - AuxStreamType.Attention: aux_stream_list[0], - AuxStreamType.MoeShared: aux_stream_list[0], - AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], - } - - layers = [ - DeepseekV3DecoderLayer( - model_config=self.model_config, - layer_idx=layer_idx, - aux_stream_dict=aux_stream_dict, - ) for layer_idx in layer_indices - ] - next_layer_layernorm = RMSNorm( - hidden_size=pretrained_config.hidden_size, - eps=pretrained_config.rms_norm_eps, - dtype=pretrained_config.torch_dtype) - - # apply_quant_config_exclude_modules - # Please refer to tensorrt_llm/_torch/models/modeling_utils.py - quant_config = self.model_config.quant_config - new_quant_config = QuantConfig( - kv_cache_quant_algo=quant_config.kv_cache_quant_algo) - for layer in layers: - for name, module in layer.named_modules(): - name = f"model.layers.{layer.layer_idx}.{name}" - candidates = [name] - if isinstance(module, Linear): - weight_mode = module.weights_loading_config.weight_mode - if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR: - # sometimes gate and up proj are not packed in the checkpoint, - # but they still share the same exclusion rule - candidates += [ - name.replace('gate_up_proj', 'gate_proj'), - name.replace('gate_up_proj', 'up_proj') - ] - elif weight_mode == WeightMode.FUSED_QKV_LINEAR: - # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint, - 
# but they still share the same exclusion rule - candidates += [ - name.replace('qkv_proj', 'q_proj'), - name.replace('qkv_proj', 'k_proj'), - name.replace('qkv_proj', 'v_proj') - ] - is_excluded = any( - quant_config.is_module_excluded_from_quantization(n) - for n in candidates) - if is_excluded and getattr(module, "quant_config", - None) is not None: - module.quant_config = new_quant_config - for name, module in layer.named_modules(): - if callable(getattr(module, "create_weights", None)): - module.create_weights() - layer.cuda() - for name, module in layer.named_modules(): - if hasattr(module, 'post_load_weights') and not getattr( - module, '_weights_removed', False): - module.post_load_weights() - next_layer_layernorm.cuda() - for layer, next_layer in zip(layers[:-1], layers[1:]): - layer.next_layer_layernorm = next_layer.input_layernorm - layers[-1].next_layer_layernorm = next_layer_layernorm - self.layers = layers - if scaled_from is not None: - WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_orig + with self.scaled_from_ctx(scaled_from, mapping, pretrained_config): + aux_stream_list = [torch.cuda.Stream() for _ in range(2)] + aux_stream_dict = { + AuxStreamType.Attention: aux_stream_list[0], + AuxStreamType.MoeShared: aux_stream_list[0], + AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + } + + layers = [ + DeepseekV3DecoderLayer( + model_config=self.model_config, + layer_idx=layer_idx, + aux_stream_dict=aux_stream_dict, + ) for layer_idx in layer_indices + ] + next_layer_layernorm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype) + + # TODO: apply_layerwise_quant_config + self.apply_quant_config_exclude_modules( + layers, self.model_config.quant_config) + for layer in layers: + for module in layer.modules(): + if callable(getattr(module, "create_weights", None)): + module.create_weights() + layer.cuda() + for module in layer.modules(): + if hasattr(module, 'post_load_weights') and not getattr( + module, '_weights_removed', False): + module.post_load_weights() + next_layer_layernorm.cuda() + for layer, next_layer in zip(layers[:-1], layers[1:]): + layer.next_layer_layernorm = next_layer.input_layernorm + layers[-1].next_layer_layernorm = next_layer_layernorm + + self.layers = layers tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = gate_cls_orig - def create_run_pack(self, - run_type: str, - batch_size: int, - seq_len_q: int, - seq_len_kv_cache: int, - kv_cache_manager: KVCacheManager, - attn_workspace: Optional[torch.Tensor] = None): - if self.model_config.moe_backend == "TRTLLM" and os.getenv( - "TRTLLM_ENABLE_PDL") != "1": - raise ValueError( - "Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM") - world_size = mpi_world_size() - AttentionCls = get_attention_backend( - self.model_config.attn_backend, - self.model_config.sparse_attention_config) - attn_metadata = AttentionCls.Metadata( - seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int), - request_ids=list(range(batch_size)), - max_num_requests=kv_cache_manager.max_batch_size, - num_contexts={ - "CTX": batch_size, - "GEN": 0, - }[run_type], - prompt_lens=[{ - "CTX": seq_len_q, - "GEN": seq_len_kv_cache, - }[run_type]] * batch_size, - max_num_tokens=batch_size * seq_len_q, - kv_cache_manager=kv_cache_manager, - kv_cache_params=KVCacheParams( - use_cache=True, - num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size, - ), - workspace=attn_workspace, - 
mapping=self.model_config.mapping, - sparse_attention_config=self.model_config.sparse_attention_config, - ) - attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q - ] * world_size - attn_metadata.prepare() - with model_extra_attrs(self.model_config.extra_attrs): - get_model_extra_attrs()["attention_metadata"] = weakref.ref( - attn_metadata) - hidden_size = self.model_config.pretrained_config.hidden_size - position_ids = torch.tensor([ - list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * - batch_size - ], - dtype=torch.int32, - device="cuda") - hidden_states = torch.rand((batch_size * seq_len_q, hidden_size), - dtype=torch.bfloat16, - device="cuda") - residual = torch.rand((batch_size * seq_len_q, hidden_size), - dtype=torch.bfloat16, - device="cuda") - - def run_pack(): - output = hidden_states, residual - with model_extra_attrs(self.model_config.extra_attrs): - with torch.inference_mode(): - for layer in self.layers: - output = layer(position_ids, output[0], attn_metadata, - output[1]) - return output - - return run_pack - def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float): if self.model_config.moe_backend not in [ @@ -351,64 +202,3 @@ def replace_routing_method(self, balance_method: BalanceMethod, for layer in self.layers: layer.mlp.gate.balance_method = balance_method layer.mlp.gate.balance_ratio = balance_ratio - - @staticmethod - def create_kv_cache_manager(pretrained_model_name_or_path, mapping, - tokens_per_block, max_batch_size, max_seq_len, - layer_indices): - # Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block` - model_config = ModelConfig.from_pretrained( - pretrained_model_name_or_path) - if model_config.enable_flash_mla: - assert tokens_per_block == 64 - - # Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager` - kv_cache_manager_cls = get_kv_cache_manager_cls(model_config) - kv_cache_manager = kv_cache_manager_cls( - KvCacheConfig( - max_tokens=max_batch_size * - round_up(max_seq_len, tokens_per_block), - enable_block_reuse=False, - ), - tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, - num_layers=len(layer_indices), - num_kv_heads=1, - head_dim=model_config.pretrained_config.kv_lora_rank + - model_config.pretrained_config.qk_rope_head_dim, - tokens_per_block=tokens_per_block, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - mapping=mapping, - dtype=torch_dtype_to_binding({ - None: torch.bfloat16, - "FP8": torch.float8_e4m3fn, - }[model_config.quant_config.kv_cache_quant_algo]), - sparse_attn_config=model_config.sparse_attention_config, - ) - kv_cache_manager.layer_offsets = { - layer_idx: i - for i, layer_idx in enumerate(layer_indices) - } - kv_cache_manager.add_dummy_requests(list(range(max_batch_size)), - [max_seq_len] * max_batch_size) - return kv_cache_manager - - @staticmethod - def create_mapping(enable_attention_dp: bool): - world_size = mpi_world_size() - rank = mpi_rank() - mapping = Mapping( - world_size=world_size, - rank=rank, - gpus_per_node=local_mpi_size(), - cp_size=1, - tp_size=world_size, - pp_size=1, - moe_cluster_size=1, - moe_tp_size=1, - moe_ep_size=world_size, - attn_tp_size=world_size, - attn_cp_size=1, - enable_attention_dp=enable_attention_dp, - ) - return mapping diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py new file mode 100644 index 00000000000..19ee41054a5 --- /dev/null +++ 
b/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py @@ -0,0 +1,90 @@ +from typing import List, Optional + +import torch + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_qwen3_next import ALL_DECODER_LAYER_TYPES +from tensorrt_llm._torch.modules.rms_norm import RMSNorm +from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.mapping import Mapping + +from .runner_base import RunnerBase +from .runner_utils import RunnerMixin + + +class Qwen3NextRunner(RunnerMixin, RunnerBase): + @staticmethod + def has_mamba_metadata() -> bool: + return True + + def __init__( + self, + pretrained_model_name_or_path: str, + mapping: Mapping, + *, + moe_backend: str, + layer_indices: List[int], + scaled_from: Optional[int], + max_seq_len: int, + max_num_tokens: int, + moe_max_num_tokens: int, + use_cuda_graph: bool, + ): + self.model_config = ModelConfig.from_pretrained( + pretrained_model_name_or_path, + mapping=mapping, + enable_min_latency=False, + use_cuda_graph=use_cuda_graph, + force_dynamic_quantization=False, + spec_config=None, + sparse_attention_config=None, # To be loaded from config + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + moe_max_num_tokens=moe_max_num_tokens, + moe_load_balancer=None, + lora_config=None, + allreduce_strategy=AllReduceStrategy.AUTO, + mm_encoder_only=False, + attn_backend="TRTLLM", + moe_backend=moe_backend, + moe_disable_finalize_fusion=False, + use_low_precision_moe_combine=False, + skip_create_weights_in_init=True, + ) + pretrained_config = self.model_config.pretrained_config + + with self.scaled_from_ctx(scaled_from, mapping, pretrained_config): + aux_stream = torch.cuda.Stream() + layers = [ + ALL_DECODER_LAYER_TYPES[pretrained_config.layer_types[layer_idx]]( + self.model_config, + layer_idx, + aux_stream, + ) + for layer_idx in layer_indices + ] + next_layer_layernorm = RMSNorm( + hidden_size=pretrained_config.hidden_size, + eps=pretrained_config.rms_norm_eps, + dtype=pretrained_config.torch_dtype, + use_gemma=True, + ) + + # TODO: apply_layerwise_quant_config + self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config) + for layer in layers: + for module in layer.modules(): + if callable(getattr(module, "create_weights", None)): + module.create_weights() + layer.cuda() + for module in layer.modules(): + if hasattr(module, "post_load_weights") and not getattr( + module, "_weights_removed", False + ): + module.post_load_weights() + next_layer_layernorm.cuda() + for layer, next_layer in zip(layers[:-1], layers[1:]): + layer.next_layer_layernorm = next_layer.input_layernorm + layers[-1].next_layer_layernorm = next_layer_layernorm + + self.layers = layers diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_base.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_base.py new file mode 100644 index 00000000000..20d672da167 --- /dev/null +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner_base.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod +from enum import IntEnum +from typing import Optional + +import torch + +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager + + +class BalanceMethod(IntEnum): + NotModified = 1 + Balanced = 2 + ImbalancedRanks = 3 + ImbalancedExperts = 4 + + +class RunnerBase(ABC): + @abstractmethod + def create_run_pack( + self, + run_type: str, + batch_size: int, + seq_len_q: int, + seq_len_kv_cache: int, + kv_cache_manager: KVCacheManager, + attn_workspace: Optional[torch.Tensor] = None, + ): 
+ pass + + @abstractmethod + def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float): + pass + + @staticmethod + @abstractmethod + def create_kv_cache_manager( + pretrained_model_name_or_path, + mapping, + tokens_per_block, + max_batch_size, + max_seq_len, + layer_indices, + ): + pass + + @staticmethod + @abstractmethod + def create_mapping(enable_attention_dp: bool): + pass diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py new file mode 100644 index 00000000000..6d712f1f6d5 --- /dev/null +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py @@ -0,0 +1,13 @@ +from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config + +from .deepseekv3_runner import DeepSeekV3Runner +from .qwen3_next_runner import Qwen3NextRunner + + +def get_runner_cls(pretrained_model_name_or_path: str): + pretrained_config = load_pretrained_config(pretrained_model_name_or_path) + return { + "deepseek_v3": DeepSeekV3Runner, + "deepseek_v32": DeepSeekV3Runner, + "qwen3_next": Qwen3NextRunner, + }[pretrained_config.model_type] diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py new file mode 100644 index 00000000000..8a5f4cba5e4 --- /dev/null +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py @@ -0,0 +1,303 @@ +import contextlib +import os +import weakref +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE +from tensorrt_llm._torch.modules.linear import Linear, WeightMode +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata +from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls +from tensorrt_llm._torch.pyexecutor.config_utils import is_mla, is_qwen3_next +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs +from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.bindings.internal.batch_manager import CacheType +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantConfig + +from .runner_base import BalanceMethod + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def round_up(a, b): + return ceil_div(a, b) * b + + +class RunnerMixin(ABC): + @staticmethod + @abstractmethod + def has_mamba_metadata() -> bool: + pass + + @staticmethod + @contextlib.contextmanager + def scaled_from_ctx(scaled_from, mapping, pretrained_config): + if scaled_from is None: + yield + return + # To run the problem size of $B$ GPUs on $A$ GPUs, we need: + # (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change. + # (2) MoE: If EP, reduce the number of experts; If TP, reduce head size. + # Maintain the result of AllToAll method selection because it is affected by EP size. 
+ if not mapping.enable_attention_dp: + if hasattr(pretrained_config, "index_n_heads"): + raise NotImplementedError("Not support Indexer TP for weak scaling") + pretrained_config.num_attention_heads = ( + pretrained_config.num_attention_heads // scaled_from * mapping.tp_size + ) + pretrained_config.num_key_value_heads = ( + pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size + ) + if mapping.moe_ep_size != mapping.world_size: + raise NotImplementedError("Not support MoE TP for weak scaling") + pretrained_config.n_routed_experts = ( + pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size + ) + select_alltoall_method_type_orig = WideEPMoE.select_alltoall_method_type + + def select_alltoall_method_type(cls: type, mapping: Mapping, top_k: int, *args, **kwargs): + # Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k` + # by replacing `top_k` with `fake_top_k` + if scaled_from <= top_k: + fake_top_k = mapping.moe_ep_size + 1 + else: + fake_top_k = mapping.moe_ep_size - 1 + assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from <= top_k) + return select_alltoall_method_type_orig(mapping, fake_top_k, *args, **kwargs) + + WideEPMoE.select_alltoall_method_type = select_alltoall_method_type + try: + yield + finally: + WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_orig + + @staticmethod + def apply_quant_config_exclude_modules(layers, quant_config): + # Please refer to tensorrt_llm/_torch/models/modeling_utils.py + new_quant_config = QuantConfig(kv_cache_quant_algo=quant_config.kv_cache_quant_algo) + for layer in layers: + for name, module in layer.named_modules(): + name = f"model.layers.{layer.layer_idx}.{name}" + candidates = [name] + if isinstance(module, Linear): + weight_mode = module.weights_loading_config.weight_mode + if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR: + # sometimes gate and up proj are not packed in the checkpoint, + # but they still share the same exclusion rule + candidates += [ + name.replace("gate_up_proj", "gate_proj"), + name.replace("gate_up_proj", "up_proj"), + ] + elif weight_mode == WeightMode.FUSED_QKV_LINEAR: + # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint, + # but they still share the same exclusion rule + candidates += [ + name.replace("qkv_proj", "q_proj"), + name.replace("qkv_proj", "k_proj"), + name.replace("qkv_proj", "v_proj"), + ] + is_excluded = any( + quant_config.is_module_excluded_from_quantization(n) for n in candidates + ) + if is_excluded and getattr(module, "quant_config", None) is not None: + module.quant_config = new_quant_config + + def create_run_pack( + self, + run_type: str, + batch_size: int, + seq_len_q: int, + seq_len_kv_cache: int, + kv_cache_manager: KVCacheManager, + attn_workspace: Optional[torch.Tensor] = None, + ): + if self.model_config.moe_backend == "TRTLLM" and os.getenv("TRTLLM_ENABLE_PDL") != "1": + raise ValueError("Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM") + world_size = mpi_world_size() + AttentionCls = get_attention_backend( + self.model_config.attn_backend, self.model_config.sparse_attention_config + ) + attn_metadata = AttentionCls.Metadata( + seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int), + request_ids=list(range(batch_size)), + max_num_requests=kv_cache_manager.max_batch_size, + num_contexts={ + "CTX": batch_size, + "GEN": 0, + }[run_type], + prompt_lens=[ + { + "CTX": seq_len_q, + "GEN": seq_len_kv_cache, + }[run_type] + ] + * batch_size, + 
max_num_tokens=batch_size * seq_len_q, + kv_cache_manager=kv_cache_manager, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size, + ), + workspace=attn_workspace, + mapping=self.model_config.mapping, + sparse_attention_config=self.model_config.sparse_attention_config, + ) + attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size + attn_metadata.prepare() + with model_extra_attrs(self.model_config.extra_attrs): + get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata) + hidden_size = self.model_config.pretrained_config.hidden_size + position_ids = torch.tensor( + [list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * batch_size], + dtype=torch.int32, + device="cuda", + ) + hidden_states = torch.rand( + (batch_size * seq_len_q, hidden_size), dtype=torch.bfloat16, device="cuda" + ) + residual = torch.rand( + (batch_size * seq_len_q, hidden_size), dtype=torch.bfloat16, device="cuda" + ) + kwargs = {} + + if self.has_mamba_metadata(): + # Please refer to `tensorrt_llm/_torch/models/modeling_qwen3_next.py` for `mamba_metadata` + mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests, chunk_size=128) + mamba_metadata.prepare(attn_metadata) + kwargs["mamba_metadata"] = mamba_metadata + + def run_pack(): + output = hidden_states, residual + with model_extra_attrs(self.model_config.extra_attrs): + with torch.inference_mode(): + for layer in self.layers: + output = layer(position_ids, output[0], attn_metadata, output[1], **kwargs) + return output + + return run_pack + + def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float): + if balance_method != BalanceMethod.NotModified: + raise NotImplementedError("not support replacing routing method for this runner") + + @staticmethod + def create_kv_cache_manager( + pretrained_model_name_or_path, + mapping, + tokens_per_block, + max_batch_size, + max_seq_len, + layer_indices, + ): + # Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block` + model_config = ModelConfig.from_pretrained(pretrained_model_name_or_path) + if model_config.enable_flash_mla: + assert tokens_per_block == 64 + + # Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager` + kv_cache_manager_cls = get_kv_cache_manager_cls(model_config) + config = model_config.pretrained_config + kv_cache_config = KvCacheConfig( + max_tokens=max_batch_size * round_up(max_seq_len, tokens_per_block), + enable_block_reuse=False, + ) + kv_cache_dtype = torch_dtype_to_binding( + { + None: torch.bfloat16, + "FP8": torch.float8_e4m3fn, + }[model_config.quant_config.kv_cache_quant_algo] + ) + if is_mla(config): + layer_mask = [i in layer_indices for i in range(config.num_hidden_layers)] + num_layers = sum(layer_mask) + kv_cache_manager = kv_cache_manager_cls( + kv_cache_config, + CacheType.SELFKONLY, + num_layers=sum(layer_mask), + num_kv_heads=1, + head_dim=model_config.pretrained_config.kv_lora_rank + + model_config.pretrained_config.qk_rope_head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + layer_mask=layer_mask, + sparse_attn_config=model_config.sparse_attention_config, + ) + elif is_qwen3_next(config): + mamba_layer_mask = [ + i in layer_indices + if i % config.full_attention_interval != config.full_attention_interval - 1 + else False + for i in range(config.num_hidden_layers) + ] + layer_mask = [ + False + if 
i % config.full_attention_interval != config.full_attention_interval - 1 + else i in layer_indices + for i in range(config.num_hidden_layers) + ] + num_mamba_layers = sum(mamba_layer_mask) + num_layers = sum(layer_mask) + kv_cache_manager = kv_cache_manager_cls( + # mamba cache parameters + config.linear_key_head_dim, + config.linear_conv_kernel_dim, + config.linear_num_value_heads, + config.linear_num_key_heads, + config.linear_value_head_dim, + num_mamba_layers, + mamba_layer_mask, + config.torch_dtype, + model_config.quant_config.mamba_ssm_cache_dtype, + # kv cache parameters + kv_cache_config, + CacheType.SELF, + num_layers=num_layers, + layer_mask=layer_mask, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + spec_config=None, + ) + else: + raise NotImplementedError("Unsupported config") + kv_cache_manager.add_dummy_requests( + list(range(max_batch_size)), [max_seq_len] * max_batch_size + ) + return kv_cache_manager + + @staticmethod + def create_mapping(enable_attention_dp: bool): + world_size = mpi_world_size() + rank = mpi_rank() + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=local_mpi_size(), + cp_size=1, + tp_size=world_size, + pp_size=1, + moe_cluster_size=1, + moe_tp_size=1, + moe_ep_size=world_size, + attn_tp_size=world_size, + attn_cp_size=1, + enable_attention_dp=enable_attention_dp, + ) + return mapping From 31c1afb0cb9b775add84565fd85f32872bc32484 Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 05:13:46 +0000 Subject: [PATCH 3/9] Add Qwen3-Next benchmarks to CI Signed-off-by: Tailing Yuan --- examples/layer_wise_benchmarks/README.md | 4 ++-- .../test_lists/test-db/l0_b200.yml | 1 + .../tools/test_layer_wise_benchmarks.py | 24 +++++++++++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/examples/layer_wise_benchmarks/README.md b/examples/layer_wise_benchmarks/README.md index 79cbb745022..6cb324cd126 100644 --- a/examples/layer_wise_benchmarks/README.md +++ b/examples/layer_wise_benchmarks/README.md @@ -49,8 +49,8 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-back NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp # Run Qwen3-Next (balanced routing is not implemented) -NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified -NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified +NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified +NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified # Run with DeepEP A2A NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 114717be909..2f70011c691 100644 --- 
a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -76,6 +76,7 @@ l0_b200: - unittest/_torch/modeling -k "modeling_llama" - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_gpt_oss" + - unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1] # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1] - unittest/_torch/auto_deploy/unit/singlegpu diff --git a/tests/unittest/tools/test_layer_wise_benchmarks.py b/tests/unittest/tools/test_layer_wise_benchmarks.py index 0c9f8b1a65a..6f7880f8d4a 100644 --- a/tests/unittest/tools/test_layer_wise_benchmarks.py +++ b/tests/unittest/tools/test_layer_wise_benchmarks.py @@ -67,3 +67,27 @@ def test_deepseek_r1_gen_scaled_from_16_dep(llm_root): **os.environ, "NP": "4", }) + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_qwen3_next_gen_tep(llm_root, tp_size): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"needs {tp_size:d} GPUs to run this test") + model_root = llm_models_root(check=True) + check_call([ + "./mpi_launch.sh", + "./run_single.sh", + "config_gen.yaml", + "--model", + model_root / "Qwen3" / "Qwen3-Next-80B-A3B-Instruct", + "--layer-indices=6,7", + "--no-enable-attention-dp", + "--moe-backend=TRTLLM", + "--balance-method=NotModified", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": f"{tp_size:d}", + "TRTLLM_ENABLE_PDL": "1", + }) From cb6ed1b304d69a5f177c592deaf6b24551b771d9 Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 06:48:02 +0000 Subject: [PATCH 4/9] Apply new formatter Signed-off-by: Tailing Yuan --- .pre-commit-config.yaml | 3 - examples/layer_wise_benchmarks/run_single.py | 98 ++++++------ pyproject.toml | 3 - .../deepseekv3_runner.py | 151 ++++++++++-------- 4 files changed, 128 insertions(+), 127 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a75e55b79d3..cc4c3b0ff09 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -83,7 +83,6 @@ common-files: &common_files | examples/infinitebench/compute_scores.py | examples/infinitebench/construct_synthetic_dataset.py | examples/infinitebench/eval_utils.py | - examples/layer_wise_benchmarks/run_single.py | examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py | examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py | examples/llm-api/_tensorrt_engine/llm_inference_customize.py | @@ -811,7 +810,6 @@ common-files: &common_files | tensorrt_llm/serve/tool_parser/utils.py | tensorrt_llm/tools/__init__.py | tensorrt_llm/tools/importlib_utils.py | - tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py | tensorrt_llm/tools/multimodal_builder.py | tensorrt_llm/tools/onnx_utils.py | tensorrt_llm/tools/plugin_gen/__init__.py | @@ -1188,7 +1186,6 @@ common-files: &common_files | tests/unittest/tools/plugin_gen/test_core.py | tests/unittest/tools/plugin_gen/test_plugin_gen.py | tests/unittest/tools/plugin_gen/test_shape_infer.py | - tests/unittest/tools/test_layer_wise_benchmarks.py | tests/unittest/tools/test_prepare_dataset.py | tests/unittest/tools/test_test_to_stage_mapping.py | tests/unittest/trt/__init__.py | diff --git a/examples/layer_wise_benchmarks/run_single.py b/examples/layer_wise_benchmarks/run_single.py index da700907bb3..fa2485a8468 100644 --- a/examples/layer_wise_benchmarks/run_single.py +++ b/examples/layer_wise_benchmarks/run_single.py @@ -9,8 +9,7 @@ from 
tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size from tensorrt_llm.tools.layer_wise_benchmarks.runner_base import BalanceMethod -from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import \ - get_runner_cls +from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import get_runner_cls def comma_separated_ints(s): @@ -24,7 +23,8 @@ def comma_separated_ints(s): parser.add_argument( "--layer-indices", type=comma_separated_ints, - help="Comma separated indices of layers, should be a contiguous range") + help="Comma separated indices of layers, should be a contiguous range", +) parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"]) parser.add_argument("--scaled-from", type=int) # KV cache related args @@ -32,24 +32,16 @@ def comma_separated_ints(s): parser.add_argument("--tokens-per-block", type=int) parser.add_argument("--max-seq-len", type=int) group = parser.add_mutually_exclusive_group(required=False) -group.add_argument("--enable-attention-dp", - action="store_true", - dest="enable_attention_dp") -group.add_argument("--no-enable-attention-dp", - action="store_false", - dest="enable_attention_dp") +group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp") +group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp") parser.set_defaults(enable_attention_dp=None) # Model init args parser.add_argument("--max-num-tokens", type=int) parser.add_argument("--moe-backend", type=str) parser.add_argument("--moe-max-num-tokens", type=int) group = parser.add_mutually_exclusive_group(required=False) -group.add_argument("--use-cuda-graph", - action="store_true", - dest="use_cuda_graph") -group.add_argument("--no-use-cuda-graph", - action="store_false", - dest="use_cuda_graph") +group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph") +group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph") parser.set_defaults(use_cuda_graph=None) # Per iteration args parser.add_argument("--batch-size", type=int) @@ -85,35 +77,41 @@ def comma_separated_ints(s): tokens_per_block=args.tokens_per_block, max_batch_size=args.max_batch_size, max_seq_len=args.max_seq_len, - layer_indices=args.layer_indices) -attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8) + layer_indices=args.layer_indices, +) +attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8) # Create other global objects AutoTuner.get().clear_cache() capture_stream = torch.cuda.Stream() # Create Runner -runner = Runner(args.model, - mapping, - moe_backend=args.moe_backend, - layer_indices=args.layer_indices, - scaled_from=args.scaled_from, - max_seq_len=args.max_seq_len, - max_num_tokens=args.max_num_tokens, - moe_max_num_tokens=args.moe_max_num_tokens, - use_cuda_graph=args.use_cuda_graph) +runner = Runner( + args.model, + mapping, + moe_backend=args.moe_backend, + layer_indices=args.layer_indices, + scaled_from=args.scaled_from, + max_seq_len=args.max_seq_len, + max_num_tokens=args.max_num_tokens, + moe_max_num_tokens=args.moe_max_num_tokens, + use_cuda_graph=args.use_cuda_graph, +) # Warm up assert args.batch_size <= args.max_batch_size assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len -run_pack = runner.create_run_pack(args.run_type, - batch_size=args.batch_size, - seq_len_q=args.seq_len_q, - seq_len_kv_cache=args.seq_len_kv_cache, - kv_cache_manager=kv_cache_manager, - 
attn_workspace=attn_workspace) -runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method], - balance_ratio=args.balance_ratio) +run_pack = runner.create_run_pack( + args.run_type, + batch_size=args.batch_size, + seq_len_q=args.seq_len_q, + seq_len_kv_cache=args.seq_len_kv_cache, + kv_cache_manager=kv_cache_manager, + attn_workspace=attn_workspace, +) +runner.replace_routing_method( + balance_method=BalanceMethod[args.balance_method], balance_ratio=args.balance_ratio +) capture_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(capture_stream): run_pack() @@ -127,21 +125,15 @@ def comma_separated_ints(s): if args.use_cuda_graph: with with_multi_stream(True): g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, - stream=capture_stream, - capture_error_mode="global"): + with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"): run_pack() warmup_times = 20 run_times = 100 -events = [ - torch.cuda.Event(enable_timing=True) - for _ in range(warmup_times + run_times + 1) -] +events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)] for i in range(warmup_times + run_times): events[i].record() - with nvtx.annotate( - f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"): + with nvtx.annotate(f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"): if args.use_cuda_graph: g.replay() else: @@ -151,16 +143,16 @@ def comma_separated_ints(s): # Print statistics # Print before `cudaProfilerStop` to ensure messages are included in the profile -time_list = [ - start.elapsed_time(stop) for start, stop in zip(events, events[1:]) -] +time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])] time_list = time_list[warmup_times:] -print(f"[RANK {rank}]" - f" min {np.min(time_list) * 1000:.1f}" - f" max {np.max(time_list) * 1000:.1f}" - f" mean {np.mean(time_list) * 1000:.1f}" - f" median {np.median(time_list) * 1000:.1f}" - f" P90 {np.percentile(time_list, 90) * 1000:.1f}" - f" (us)") +print( + f"[RANK {rank}]" + f" min {np.min(time_list) * 1000:.1f}" + f" max {np.max(time_list) * 1000:.1f}" + f" mean {np.mean(time_list) * 1000:.1f}" + f" median {np.median(time_list) * 1000:.1f}" + f" P90 {np.percentile(time_list, 90) * 1000:.1f}" + f" (us)" +) torch.cuda.cudart().cudaProfilerStop() diff --git a/pyproject.toml b/pyproject.toml index 21e2921ad61..2ddd3c1b1ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,6 @@ exclude = [ "examples/infinitebench/compute_scores.py", "examples/infinitebench/construct_synthetic_dataset.py", "examples/infinitebench/eval_utils.py", - "examples/layer_wise_benchmarks/run_single.py", "examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py", "examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py", "examples/llm-api/_tensorrt_engine/llm_inference_customize.py", @@ -851,7 +850,6 @@ exclude = [ "tensorrt_llm/serve/tool_parser/utils.py", "tensorrt_llm/tools/__init__.py", "tensorrt_llm/tools/importlib_utils.py", - "tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py", "tensorrt_llm/tools/multimodal_builder.py", "tensorrt_llm/tools/onnx_utils.py", "tensorrt_llm/tools/plugin_gen/__init__.py", @@ -1228,7 +1226,6 @@ exclude = [ "tests/unittest/tools/plugin_gen/test_core.py", "tests/unittest/tools/plugin_gen/test_plugin_gen.py", "tests/unittest/tools/plugin_gen/test_shape_infer.py", - "tests/unittest/tools/test_layer_wise_benchmarks.py", "tests/unittest/tools/test_prepare_dataset.py", "tests/unittest/tools/test_test_to_stage_mapping.py", 
"tests/unittest/trt/__init__.py", diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py index 4469096eb2d..f9ae1b765ca 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py @@ -5,8 +5,7 @@ import tensorrt_llm._torch.models.modeling_deepseekv3 from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._torch.models.modeling_deepseekv3 import ( - DeepseekV3DecoderLayer, DeepseekV3Gate) +from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3DecoderLayer, DeepseekV3Gate from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm._torch.utils import AuxStreamType from tensorrt_llm._utils import mpi_rank, mpi_world_size @@ -18,7 +17,6 @@ class RoutingMethod(DeepseekV3Gate): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.world_size = mpi_world_size() @@ -27,103 +25,120 @@ def __init__(self, *args, **kwargs): self.balance_ratio = None def apply(self, router_logits) -> (torch.Tensor, torch.Tensor): - token_selected_experts, token_final_scales = super().apply( - router_logits) + token_selected_experts, token_final_scales = super().apply(router_logits) num_experts = self.weight.shape[0] if self.balance_method == BalanceMethod.NotModified: pass elif self.balance_method == BalanceMethod.Balanced: token_selected_experts = RoutingMethod.get_balanced_selection( token_selected_experts.shape[0], - token_selected_experts.shape[1], num_experts, - token_selected_experts.dtype, self.world_size, self.rank) + token_selected_experts.shape[1], + num_experts, + token_selected_experts.dtype, + self.world_size, + self.rank, + ) elif self.balance_method == BalanceMethod.ImbalancedRanks: token_selected_experts = RoutingMethod.get_all_to_one_selection( token_selected_experts.shape[0], - token_selected_experts.shape[1], num_experts, - self.balance_ratio, token_selected_experts.dtype, - self.world_size, self.rank) + token_selected_experts.shape[1], + num_experts, + self.balance_ratio, + token_selected_experts.dtype, + self.world_size, + self.rank, + ) elif self.balance_method == BalanceMethod.ImbalancedExperts: token_selected_experts = RoutingMethod.get_balanced_rank_imbalanced_expert_selection( token_selected_experts.shape[0], - token_selected_experts.shape[1], num_experts, - self.balance_ratio, token_selected_experts.dtype, - self.world_size, self.rank) + token_selected_experts.shape[1], + num_experts, + self.balance_ratio, + token_selected_experts.dtype, + self.world_size, + self.rank, + ) else: - raise NotImplementedError( - f"Not support balance_method {self.balance_method}") + raise NotImplementedError(f"Not support balance_method {self.balance_method}") return token_selected_experts, token_final_scales @functools.cache @staticmethod - def get_balanced_selection(num_tokens, top_k, num_experts, dtype, - world_size, rank): - a = torch.arange(num_tokens * world_size * top_k, - dtype=dtype, - device="cuda").view(num_tokens, world_size, - top_k)[:, rank] - experts = (a * (num_experts // world_size + 1) + a // num_experts * - (num_experts // world_size)) % num_experts + def get_balanced_selection(num_tokens, top_k, num_experts, dtype, world_size, rank): + a = torch.arange(num_tokens * world_size * top_k, dtype=dtype, device="cuda").view( + num_tokens, world_size, top_k + )[:, rank] + experts = ( + a * (num_experts // world_size + 1) + a // num_experts * (num_experts // 
world_size) + ) % num_experts return experts.contiguous() @staticmethod - def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, - world_size, rank): + def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank): num_tokens, top_k = imbalanced_experts.shape dtype = imbalanced_experts.dtype balanced_experts = RoutingMethod.get_balanced_selection( - num_tokens, top_k, num_experts, dtype, world_size, rank) + num_tokens, top_k, num_experts, dtype, world_size, rank + ) num_balanced_tokens = round(num_tokens * balance_ratio) if balance_ratio != 0: # Activate all experts - num_balanced_tokens = max(num_balanced_tokens, - ceil_div(num_experts, world_size * top_k)) + num_balanced_tokens = max( + num_balanced_tokens, ceil_div(num_experts, world_size * top_k) + ) mixed_experts = balanced_experts.clone() - mixed_experts[num_balanced_tokens:] = imbalanced_experts[ - num_balanced_tokens:] + mixed_experts[num_balanced_tokens:] = imbalanced_experts[num_balanced_tokens:] return mixed_experts @functools.cache @staticmethod - def get_all_to_one_selection(num_tokens, top_k, num_experts, balance_ratio, - dtype, world_size, rank): + def get_all_to_one_selection( + num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank + ): assert num_experts // world_size >= top_k - imbalanced_experts = torch.arange( - num_tokens * top_k, dtype=dtype, device="cuda").view( - num_tokens, top_k) % (num_experts // world_size) - return RoutingMethod.apply_balance_ratio(imbalanced_experts, - num_experts, balance_ratio, - world_size, rank) + imbalanced_experts = torch.arange(num_tokens * top_k, dtype=dtype, device="cuda").view( + num_tokens, top_k + ) % (num_experts // world_size) + return RoutingMethod.apply_balance_ratio( + imbalanced_experts, num_experts, balance_ratio, world_size, rank + ) @functools.cache @staticmethod - def get_balanced_rank_imbalanced_expert_selection(num_tokens, top_k, - num_experts, - balance_ratio, dtype, - world_size, rank): + def get_balanced_rank_imbalanced_expert_selection( + num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank + ): experts_per_rank = num_experts // world_size activate_experts_per_rank = ceil_div(top_k, world_size) - a = torch.arange(num_tokens * top_k, dtype=dtype, - device="cuda").view(num_tokens, top_k) + a = torch.arange(num_tokens * top_k, dtype=dtype, device="cuda").view(num_tokens, top_k) narrow_experts = a % (activate_experts_per_rank * world_size) - imbalanced_experts = narrow_experts * experts_per_rank % num_experts + narrow_experts // world_size % experts_per_rank - return RoutingMethod.apply_balance_ratio(imbalanced_experts, - num_experts, balance_ratio, - world_size, rank) + imbalanced_experts = ( + narrow_experts * experts_per_rank % num_experts + + narrow_experts // world_size % experts_per_rank + ) + return RoutingMethod.apply_balance_ratio( + imbalanced_experts, num_experts, balance_ratio, world_size, rank + ) class DeepSeekV3Runner(RunnerMixin, RunnerBase): - @staticmethod def has_mamba_metadata() -> bool: return False - def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, - moe_backend: str, layer_indices: List[int], - scaled_from: Optional[int], max_seq_len: int, - max_num_tokens: int, moe_max_num_tokens: int, - use_cuda_graph: bool): - + def __init__( + self, + pretrained_model_name_or_path: str, + mapping: Mapping, + *, + moe_backend: str, + layer_indices: List[int], + scaled_from: Optional[int], + max_seq_len: int, + max_num_tokens: int, + moe_max_num_tokens: 
int, + use_cuda_graph: bool, + ): # Temporally replace the gate class gate_cls_orig = tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = RoutingMethod @@ -164,24 +179,26 @@ def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, model_config=self.model_config, layer_idx=layer_idx, aux_stream_dict=aux_stream_dict, - ) for layer_idx in layer_indices + ) + for layer_idx in layer_indices ] next_layer_layernorm = RMSNorm( hidden_size=pretrained_config.hidden_size, eps=pretrained_config.rms_norm_eps, - dtype=pretrained_config.torch_dtype) + dtype=pretrained_config.torch_dtype, + ) # TODO: apply_layerwise_quant_config - self.apply_quant_config_exclude_modules( - layers, self.model_config.quant_config) + self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config) for layer in layers: for module in layer.modules(): if callable(getattr(module, "create_weights", None)): module.create_weights() layer.cuda() for module in layer.modules(): - if hasattr(module, 'post_load_weights') and not getattr( - module, '_weights_removed', False): + if hasattr(module, "post_load_weights") and not getattr( + module, "_weights_removed", False + ): module.post_load_weights() next_layer_layernorm.cuda() for layer, next_layer in zip(layers[:-1], layers[1:]): @@ -191,14 +208,12 @@ def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *, self.layers = layers tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = gate_cls_orig - def replace_routing_method(self, balance_method: BalanceMethod, - balance_ratio: float): - if self.model_config.moe_backend not in [ - "CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP" - ]: + def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float): + if self.model_config.moe_backend not in ["CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP"]: raise NotImplementedError( - f"Not support replace routing method for moe_backend \"{self.model_config.moe_backend}\"," - f" please set balance_method to \"NotModified\"") + f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",' + f' please set balance_method to "NotModified"' + ) for layer in self.layers: layer.mlp.gate.balance_method = balance_method layer.mlp.gate.balance_ratio = balance_ratio From a83bdbc507a9a2a412fa5fae1d1c7cc5cdedc964 Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 08:02:19 +0000 Subject: [PATCH 5/9] Update interface Signed-off-by: Tailing Yuan --- examples/layer_wise_benchmarks/run_single.py | 3 +-- tensorrt_llm/tools/layer_wise_benchmarks/__init__.py | 7 +++++++ .../tools/layer_wise_benchmarks/deepseekv3_runner.py | 2 +- .../tools/layer_wise_benchmarks/qwen3_next_runner.py | 2 +- tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py | 2 +- .../{runner_base.py => runner_interface.py} | 0 tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py | 2 +- 7 files changed, 12 insertions(+), 6 deletions(-) rename tensorrt_llm/tools/layer_wise_benchmarks/{runner_base.py => runner_interface.py} (100%) diff --git a/examples/layer_wise_benchmarks/run_single.py b/examples/layer_wise_benchmarks/run_single.py index fa2485a8468..77b95c35a6c 100644 --- a/examples/layer_wise_benchmarks/run_single.py +++ b/examples/layer_wise_benchmarks/run_single.py @@ -8,8 +8,7 @@ from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream from tensorrt_llm._utils import 
local_mpi_rank, mpi_rank, mpi_world_size -from tensorrt_llm.tools.layer_wise_benchmarks.runner_base import BalanceMethod -from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import get_runner_cls +from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls def comma_separated_ints(s): diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py b/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py index e69de29bb2d..607e110b62e 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/__init__.py @@ -0,0 +1,7 @@ +from .runner_factory import get_runner_cls +from .runner_interface import BalanceMethod + +__all__ = [ + "BalanceMethod", + "get_runner_cls", +] diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py index f9ae1b765ca..27b16e94bfc 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py @@ -12,7 +12,7 @@ from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.mapping import Mapping -from .runner_base import BalanceMethod, RunnerBase +from .runner_interface import BalanceMethod, RunnerBase from .runner_utils import RunnerMixin, ceil_div diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py index 19ee41054a5..888d0c4574e 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/qwen3_next_runner.py @@ -8,7 +8,7 @@ from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.mapping import Mapping -from .runner_base import RunnerBase +from .runner_interface import RunnerBase from .runner_utils import RunnerMixin diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py index 6d712f1f6d5..b45d1e8e5ba 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py @@ -4,7 +4,7 @@ from .qwen3_next_runner import Qwen3NextRunner -def get_runner_cls(pretrained_model_name_or_path: str): +def get_runner_cls(pretrained_model_name_or_path: str) -> type: pretrained_config = load_pretrained_config(pretrained_model_name_or_path) return { "deepseek_v3": DeepSeekV3Runner, diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_base.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_interface.py similarity index 100% rename from tensorrt_llm/tools/layer_wise_benchmarks/runner_base.py rename to tensorrt_llm/tools/layer_wise_benchmarks/runner_interface.py diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py index 8a5f4cba5e4..ea7201e7f2b 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py @@ -22,7 +22,7 @@ from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from .runner_base import BalanceMethod +from .runner_interface import BalanceMethod def ceil_div(a, b): From 8334235ad23387df7802662837af69818aaf178b Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 08:44:12 +0000 Subject: [PATCH 6/9] Fix the order between @staticmethod and @functools.cache Signed-off-by: Tailing Yuan --- 
.../tools/layer_wise_benchmarks/deepseekv3_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py index 27b16e94bfc..f5d41374eca 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py @@ -62,8 +62,8 @@ def apply(self, router_logits) -> (torch.Tensor, torch.Tensor): raise NotImplementedError(f"Not support balance_method {self.balance_method}") return token_selected_experts, token_final_scales - @functools.cache @staticmethod + @functools.cache def get_balanced_selection(num_tokens, top_k, num_experts, dtype, world_size, rank): a = torch.arange(num_tokens * world_size * top_k, dtype=dtype, device="cuda").view( num_tokens, world_size, top_k @@ -90,8 +90,8 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si mixed_experts[num_balanced_tokens:] = imbalanced_experts[num_balanced_tokens:] return mixed_experts - @functools.cache @staticmethod + @functools.cache def get_all_to_one_selection( num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank ): @@ -103,8 +103,8 @@ def get_all_to_one_selection( imbalanced_experts, num_experts, balance_ratio, world_size, rank ) - @functools.cache @staticmethod + @functools.cache def get_balanced_rank_imbalanced_expert_selection( num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank ): From 10d7f1952d5328d3b44f5d761fdc68eb32fcb382 Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 08:51:14 +0000 Subject: [PATCH 7/9] Auto trigger multi-gpu CI Signed-off-by: Tailing Yuan --- jenkins/L0_MergeRequest.groovy | 9 ++++++++- tests/integration/test_lists/test-db/l0_dgx_b200.yml | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index 097818b95b8..fc8ca6b3017 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -627,9 +627,16 @@ def getAutoTriggerTagList(pipeline, testFilter, globalVars) { return autoTriggerTagList } def specialFileToTagMap = [ - "tensorrt_llm/_torch/models/modeling_deepseekv3.py": ["-DeepSeek-"], "cpp/kernels/fmha_v2/": ["-FMHA-"], + "examples/layer_wise_benchmarks/config_ctx.yaml": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], + "examples/layer_wise_benchmarks/config_gen.yaml": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], + "examples/layer_wise_benchmarks/mpi_launch.sh": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], + "examples/layer_wise_benchmarks/run_single.py": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], + "examples/layer_wise_benchmarks/run_single.sh": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], + "tensorrt_llm/_torch/models/modeling_deepseekv3.py": ["-DeepSeek-"], "tensorrt_llm/_torch/models/modeling_gpt_oss.py": ["-GptOss-"], + "tensorrt_llm/tools/layer_wise_benchmarks/": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], + "tests/unittest/tools/test_layer_wise_benchmarks.py": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], ] for (file in changedFileList) { for (String key : specialFileToTagMap.keySet()) { diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 3de071247da..f02f95d2aca 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -18,8 +18,6 @@ l0_dgx_b200: - unittest/_torch/multi_gpu_modeling -k "deepseek" - 
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL] - - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_tep - - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_gen_scaled_from_16_dep - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] @@ -144,6 +142,8 @@ l0_dgx_b200: orchestrator: mpi tests: - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP] + - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_tep + - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_gen_scaled_from_16_dep - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] From 02cf0dd64707ff1c426cd8f71295ff04e69fc6eb Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Tue, 11 Nov 2025 09:38:25 +0000 Subject: [PATCH 8/9] Apply new formatter to test_layer_wise_benchmarks.py Signed-off-by: Tailing Yuan --- .../unit/singlegpu/test_ad_trtllm_bench.py | 1 - tests/unittest/conftest.py | 1 + .../tools/test_layer_wise_benchmarks.py | 140 +++++++++--------- tests/unittest/tools/test_prepare_dataset.py | 1 - 4 files changed, 73 insertions(+), 70 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py index 7c4da257bfa..e6515d3d802 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py @@ -6,7 +6,6 @@ import yaml from _model_test_utils import get_small_model_config from click.testing import CliRunner -from utils.cpp_paths import llm_root # noqa: F401 from tensorrt_llm.commands.bench import main diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index 5b1557e4c63..97c48b6b96d 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -25,6 +25,7 @@ import torch import tqdm from mpi4py.futures import MPIPoolExecutor +from utils.cpp_paths import llm_root # noqa: F401 from utils.util import get_current_process_gpu_memory sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/tests/unittest/tools/test_layer_wise_benchmarks.py b/tests/unittest/tools/test_layer_wise_benchmarks.py index 6f7880f8d4a..72e5cff6ef4 100644 --- a/tests/unittest/tools/test_layer_wise_benchmarks.py +++ b/tests/unittest/tools/test_layer_wise_benchmarks.py @@ -3,70 +3,72 @@ import pytest import torch -from utils.cpp_paths import llm_root # noqa: F401 from utils.llm_data import llm_models_root -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="needs 4 GPUs to run this test") +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") def 
test_deepseek_r1_ctx_tep(llm_root): model_root = llm_models_root(check=True) - check_call([ - "./mpi_launch.sh", - "./run_single.sh", - "config_ctx.yaml", - "--model", - model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", - "--no-enable-attention-dp", - "--moe-backend=TRTLLM", - ], - cwd=llm_root / "examples" / "layer_wise_benchmarks", - env={ - **os.environ, - "NP": "4", - "TRTLLM_ENABLE_PDL": "1", - }) + check_call( + [ + "./mpi_launch.sh", + "./run_single.sh", + "config_ctx.yaml", + "--model", + model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", + "--no-enable-attention-dp", + "--moe-backend=TRTLLM", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": "4", + "TRTLLM_ENABLE_PDL": "1", + }, + ) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="needs 4 GPUs to run this test") +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") def test_deepseek_v32_ctx_dep(llm_root): model_root = llm_models_root(check=True) - check_call([ - "./mpi_launch.sh", - "./run_single.sh", - "config_ctx.yaml", - "--model", - model_root / "DeepSeek-V3.2-Exp-hf", - "--tokens-per-block=64", - "--moe-backend=DEEPGEMM", - ], - cwd=llm_root / "examples" / "layer_wise_benchmarks", - env={ - **os.environ, - "NP": "4", - }) + check_call( + [ + "./mpi_launch.sh", + "./run_single.sh", + "config_ctx.yaml", + "--model", + model_root / "DeepSeek-V3.2-Exp-hf", + "--tokens-per-block=64", + "--moe-backend=DEEPGEMM", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": "4", + }, + ) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="needs 4 GPUs to run this test") +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") def test_deepseek_r1_gen_scaled_from_16_dep(llm_root): model_root = llm_models_root(check=True) - check_call([ - "./mpi_launch.sh", - "./run_single.sh", - "config_gen.yaml", - "--model", - model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", - "--layer-indices=5,6", - "--scaled-from=16", - "--moe-backend=WIDEEP", - ], - cwd=llm_root / "examples" / "layer_wise_benchmarks", - env={ - **os.environ, - "NP": "4", - }) + check_call( + [ + "./mpi_launch.sh", + "./run_single.sh", + "config_gen.yaml", + "--model", + model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", + "--layer-indices=5,6", + "--scaled-from=16", + "--moe-backend=WIDEEP", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": "4", + }, + ) @pytest.mark.parametrize("tp_size", [1, 2, 4]) @@ -74,20 +76,22 @@ def test_qwen3_next_gen_tep(llm_root, tp_size): if torch.cuda.device_count() < tp_size: pytest.skip(f"needs {tp_size:d} GPUs to run this test") model_root = llm_models_root(check=True) - check_call([ - "./mpi_launch.sh", - "./run_single.sh", - "config_gen.yaml", - "--model", - model_root / "Qwen3" / "Qwen3-Next-80B-A3B-Instruct", - "--layer-indices=6,7", - "--no-enable-attention-dp", - "--moe-backend=TRTLLM", - "--balance-method=NotModified", - ], - cwd=llm_root / "examples" / "layer_wise_benchmarks", - env={ - **os.environ, - "NP": f"{tp_size:d}", - "TRTLLM_ENABLE_PDL": "1", - }) + check_call( + [ + "./mpi_launch.sh", + "./run_single.sh", + "config_gen.yaml", + "--model", + model_root / "Qwen3" / "Qwen3-Next-80B-A3B-Instruct", + "--layer-indices=6,7", + "--no-enable-attention-dp", + "--moe-backend=TRTLLM", + "--balance-method=NotModified", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": 
f"{tp_size:d}", + "TRTLLM_ENABLE_PDL": "1", + }, + ) diff --git a/tests/unittest/tools/test_prepare_dataset.py b/tests/unittest/tools/test_prepare_dataset.py index 05da19a5957..df2c8e9d1b3 100644 --- a/tests/unittest/tools/test_prepare_dataset.py +++ b/tests/unittest/tools/test_prepare_dataset.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Tuple import pytest -from utils.cpp_paths import llm_root # noqa: F401 from utils.llm_data import llm_models_root # Constants for test configuration From 175a73fb4519b620f16dc7b7c6eb1003d434c542 Mon Sep 17 00:00:00 2001 From: Tailing Yuan Date: Thu, 13 Nov 2025 05:44:07 +0000 Subject: [PATCH 9/9] Move tests to single-gpu Signed-off-by: Tailing Yuan --- jenkins/L0_MergeRequest.groovy | 9 +-- .../test_lists/test-db/l0_b200.yml | 1 + .../test_lists/test-db/l0_dgx_b200.yml | 2 - .../tools/test_layer_wise_benchmarks.py | 55 ++++++++++++++----- 4 files changed, 43 insertions(+), 24 deletions(-) diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index fc8ca6b3017..097818b95b8 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -627,16 +627,9 @@ def getAutoTriggerTagList(pipeline, testFilter, globalVars) { return autoTriggerTagList } def specialFileToTagMap = [ - "cpp/kernels/fmha_v2/": ["-FMHA-"], - "examples/layer_wise_benchmarks/config_ctx.yaml": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], - "examples/layer_wise_benchmarks/config_gen.yaml": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], - "examples/layer_wise_benchmarks/mpi_launch.sh": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], - "examples/layer_wise_benchmarks/run_single.py": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], - "examples/layer_wise_benchmarks/run_single.sh": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], "tensorrt_llm/_torch/models/modeling_deepseekv3.py": ["-DeepSeek-"], + "cpp/kernels/fmha_v2/": ["-FMHA-"], "tensorrt_llm/_torch/models/modeling_gpt_oss.py": ["-GptOss-"], - "tensorrt_llm/tools/layer_wise_benchmarks/": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], - "tests/unittest/tools/test_layer_wise_benchmarks.py": ["DGX_B200-4_GPUs-PyTorch-Post-Merge"], ] for (file in changedFileList) { for (String key : specialFileToTagMap.keySet()) { diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 2f70011c691..0972c6b0b4b 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -76,6 +76,7 @@ l0_b200: - unittest/_torch/modeling -k "modeling_llama" - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_gpt_oss" + - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_dep[1] - unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1] # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index e1be6ba1ba8..312ae9963ce 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -143,8 +143,6 @@ l0_dgx_b200: orchestrator: mpi tests: - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEP] - - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_tep - - unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_gen_scaled_from_16_dep - 
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] diff --git a/tests/unittest/tools/test_layer_wise_benchmarks.py b/tests/unittest/tools/test_layer_wise_benchmarks.py index 72e5cff6ef4..14a02a9ae07 100644 --- a/tests/unittest/tools/test_layer_wise_benchmarks.py +++ b/tests/unittest/tools/test_layer_wise_benchmarks.py @@ -6,8 +6,31 @@ from utils.llm_data import llm_models_root -@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") -def test_deepseek_r1_ctx_tep(llm_root): +@pytest.mark.parametrize("world_size", [1, 4]) +def test_deepseek_r1_ctx_dep(llm_root, world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"needs {world_size:d} GPUs to run this test") + model_root = llm_models_root(check=True) + check_call( + [ + "./mpi_launch.sh", + "./run_single.sh", + "config_ctx.yaml", + "--model", + model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", + ], + cwd=llm_root / "examples" / "layer_wise_benchmarks", + env={ + **os.environ, + "NP": f"{world_size:d}", + }, + ) + + +@pytest.mark.parametrize("world_size", [1, 4]) +def test_deepseek_r1_ctx_tep(llm_root, world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"needs {world_size:d} GPUs to run this test") model_root = llm_models_root(check=True) check_call( [ @@ -22,14 +45,16 @@ def test_deepseek_r1_ctx_tep(llm_root): cwd=llm_root / "examples" / "layer_wise_benchmarks", env={ **os.environ, - "NP": "4", + "NP": f"{world_size:d}", "TRTLLM_ENABLE_PDL": "1", }, ) -@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") -def test_deepseek_v32_ctx_dep(llm_root): +@pytest.mark.parametrize("world_size", [1, 4]) +def test_deepseek_v32_ctx_dep(llm_root, world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"needs {world_size:d} GPUs to run this test") model_root = llm_models_root(check=True) check_call( [ @@ -44,13 +69,15 @@ def test_deepseek_v32_ctx_dep(llm_root): cwd=llm_root / "examples" / "layer_wise_benchmarks", env={ **os.environ, - "NP": "4", + "NP": f"{world_size:d}", }, ) -@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") -def test_deepseek_r1_gen_scaled_from_16_dep(llm_root): +@pytest.mark.parametrize("world_size", [4]) +def test_deepseek_r1_gen_scaled_from_16_dep(llm_root, world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"needs {world_size:d} GPUs to run this test") model_root = llm_models_root(check=True) check_call( [ @@ -66,15 +93,15 @@ def test_deepseek_r1_gen_scaled_from_16_dep(llm_root): cwd=llm_root / "examples" / "layer_wise_benchmarks", env={ **os.environ, - "NP": "4", + "NP": f"{world_size:d}", }, ) -@pytest.mark.parametrize("tp_size", [1, 2, 4]) -def test_qwen3_next_gen_tep(llm_root, tp_size): - if torch.cuda.device_count() < tp_size: - pytest.skip(f"needs {tp_size:d} GPUs to run this test") +@pytest.mark.parametrize("world_size", [1, 4]) +def test_qwen3_next_gen_tep(llm_root, world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"needs {world_size:d} GPUs to run this test") model_root = llm_models_root(check=True) check_call( [ @@ -91,7 +118,7 @@ def 
test_qwen3_next_gen_tep(llm_root, tp_size): cwd=llm_root / "examples" / "layer_wise_benchmarks", env={ **os.environ, - "NP": f"{tp_size:d}", + "NP": f"{world_size:d}", "TRTLLM_ENABLE_PDL": "1", }, )