88from tensorrt_llm ._torch .autotuner import AutoTuner , autotune
99from tensorrt_llm ._torch .modules .multi_stream_utils import with_multi_stream
1010from tensorrt_llm ._utils import local_mpi_rank , mpi_rank , mpi_world_size
11- from tensorrt_llm .tools .layer_wise_benchmarks .deepseekv3_runner import (
12- BalanceMethod , DeepSeekV3Runner )
11+ from tensorrt_llm .tools .layer_wise_benchmarks .runner_base import BalanceMethod
12+ from tensorrt_llm .tools .layer_wise_benchmarks .runner_factory import \
13+ get_runner_cls
1314
1415
1516def comma_separated_ints (s ):
@@ -76,9 +77,9 @@ def comma_separated_ints(s):
7677torch .cuda .set_device (local_rank )
7778
7879# Create KV cache manager
79- mapping = DeepSeekV3Runner . create_mapping (
80- enable_attention_dp = args .enable_attention_dp )
81- kv_cache_manager = DeepSeekV3Runner .create_kv_cache_manager (
80+ Runner = get_runner_cls ( args . model )
81+ mapping = Runner . create_mapping ( enable_attention_dp = args .enable_attention_dp )
82+ kv_cache_manager = Runner .create_kv_cache_manager (
8283 args .model ,
8384 mapping ,
8485 tokens_per_block = args .tokens_per_block ,
@@ -92,15 +93,15 @@ def comma_separated_ints(s):
9293capture_stream = torch .cuda .Stream ()
9394
9495# Create Runner
95- runner = DeepSeekV3Runner (args .model ,
96- mapping ,
97- moe_backend = args .moe_backend ,
98- layer_indices = args .layer_indices ,
99- scaled_from = args .scaled_from ,
100- max_seq_len = args .max_seq_len ,
101- max_num_tokens = args .max_num_tokens ,
102- moe_max_num_tokens = args .moe_max_num_tokens ,
103- use_cuda_graph = args .use_cuda_graph )
96+ runner = Runner (args .model ,
97+ mapping ,
98+ moe_backend = args .moe_backend ,
99+ layer_indices = args .layer_indices ,
100+ scaled_from = args .scaled_from ,
101+ max_seq_len = args .max_seq_len ,
102+ max_num_tokens = args .max_num_tokens ,
103+ moe_max_num_tokens = args .moe_max_num_tokens ,
104+ use_cuda_graph = args .use_cuda_graph )
104105
105106# Warm up
106107assert args .batch_size <= args .max_batch_size
0 commit comments