 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
 from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
 from tensorrt_llm.tools.layer_wise_benchmarks.runner_base import BalanceMethod
-from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import \
-    get_runner_cls
+from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import get_runner_cls


 def comma_separated_ints(s):
@@ -24,32 +23,25 @@ def comma_separated_ints(s):
 parser.add_argument(
     "--layer-indices",
     type=comma_separated_ints,
-    help="Comma separated indices of layers, should be a contiguous range")
+    help="Comma separated indices of layers, should be a contiguous range",
+)
 parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
 parser.add_argument("--scaled-from", type=int)
 # KV cache related args
 parser.add_argument("--max-batch-size", type=int)
 parser.add_argument("--tokens-per-block", type=int)
 parser.add_argument("--max-seq-len", type=int)
 group = parser.add_mutually_exclusive_group(required=False)
-group.add_argument("--enable-attention-dp",
-                   action="store_true",
-                   dest="enable_attention_dp")
-group.add_argument("--no-enable-attention-dp",
-                   action="store_false",
-                   dest="enable_attention_dp")
+group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
+group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
 parser.set_defaults(enable_attention_dp=None)
 # Model init args
 parser.add_argument("--max-num-tokens", type=int)
 parser.add_argument("--moe-backend", type=str)
 parser.add_argument("--moe-max-num-tokens", type=int)
 group = parser.add_mutually_exclusive_group(required=False)
-group.add_argument("--use-cuda-graph",
-                   action="store_true",
-                   dest="use_cuda_graph")
-group.add_argument("--no-use-cuda-graph",
-                   action="store_false",
-                   dest="use_cuda_graph")
+group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
+group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
 parser.set_defaults(use_cuda_graph=None)
 # Per iteration args
 parser.add_argument("--batch-size", type=int)
@@ -85,35 +77,41 @@ def comma_separated_ints(s):
     tokens_per_block=args.tokens_per_block,
     max_batch_size=args.max_batch_size,
     max_seq_len=args.max_seq_len,
-    layer_indices=args.layer_indices)
-attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)
+    layer_indices=args.layer_indices,
+)
+attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)

 # Create other global objects
 AutoTuner.get().clear_cache()
 capture_stream = torch.cuda.Stream()

 # Create Runner
-runner = Runner(args.model,
-                mapping,
-                moe_backend=args.moe_backend,
-                layer_indices=args.layer_indices,
-                scaled_from=args.scaled_from,
-                max_seq_len=args.max_seq_len,
-                max_num_tokens=args.max_num_tokens,
-                moe_max_num_tokens=args.moe_max_num_tokens,
-                use_cuda_graph=args.use_cuda_graph)
+runner = Runner(
+    args.model,
+    mapping,
+    moe_backend=args.moe_backend,
+    layer_indices=args.layer_indices,
+    scaled_from=args.scaled_from,
+    max_seq_len=args.max_seq_len,
+    max_num_tokens=args.max_num_tokens,
+    moe_max_num_tokens=args.moe_max_num_tokens,
+    use_cuda_graph=args.use_cuda_graph,
+)

 # Warm up
 assert args.batch_size <= args.max_batch_size
 assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
-run_pack = runner.create_run_pack(args.run_type,
-                                  batch_size=args.batch_size,
-                                  seq_len_q=args.seq_len_q,
-                                  seq_len_kv_cache=args.seq_len_kv_cache,
-                                  kv_cache_manager=kv_cache_manager,
-                                  attn_workspace=attn_workspace)
-runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
-                              balance_ratio=args.balance_ratio)
+run_pack = runner.create_run_pack(
+    args.run_type,
+    batch_size=args.batch_size,
+    seq_len_q=args.seq_len_q,
+    seq_len_kv_cache=args.seq_len_kv_cache,
+    kv_cache_manager=kv_cache_manager,
+    attn_workspace=attn_workspace,
+)
+runner.replace_routing_method(
+    balance_method=BalanceMethod[args.balance_method], balance_ratio=args.balance_ratio
+)
 capture_stream.wait_stream(torch.cuda.current_stream())
 with torch.cuda.stream(capture_stream):
     run_pack()
@@ -127,21 +125,15 @@ def comma_separated_ints(s):
 if args.use_cuda_graph:
     with with_multi_stream(True):
         g = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(g,
-                              stream=capture_stream,
-                              capture_error_mode="global"):
+        with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
             run_pack()

 warmup_times = 20
 run_times = 100
-events = [
-    torch.cuda.Event(enable_timing=True)
-    for _ in range(warmup_times + run_times + 1)
-]
+events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)]
 for i in range(warmup_times + run_times):
     events[i].record()
-    with nvtx.annotate(
-            f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
+    with nvtx.annotate(f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
         if args.use_cuda_graph:
             g.replay()
         else:
@@ -151,16 +143,16 @@ def comma_separated_ints(s):

 # Print statistics
 # Print before `cudaProfilerStop` to ensure messages are included in the profile
-time_list = [
-    start.elapsed_time(stop) for start, stop in zip(events, events[1:])
-]
+time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
 time_list = time_list[warmup_times:]
-print(f"[RANK {rank}]"
-      f" min {np.min(time_list) * 1000:.1f}"
-      f" max {np.max(time_list) * 1000:.1f}"
-      f" mean {np.mean(time_list) * 1000:.1f}"
-      f" median {np.median(time_list) * 1000:.1f}"
-      f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
-      f" (us)")
+print(
+    f"[RANK {rank}]"
+    f" min {np.min(time_list) * 1000:.1f}"
+    f" max {np.max(time_list) * 1000:.1f}"
+    f" mean {np.mean(time_list) * 1000:.1f}"
+    f" median {np.median(time_list) * 1000:.1f}"
+    f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
+    f" (us)"
+)

 torch.cuda.cudart().cudaProfilerStop()
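
A minimal, self-contained sketch of the measurement pattern the script above relies on (CUDA graph capture plus paired CUDA events, warmup iterations dropped, statistics in microseconds). The work() function and the iteration counts here are illustrative placeholders, not part of the benchmark script:

import numpy as np
import torch

assert torch.cuda.is_available()


def work(x):
    # Placeholder workload standing in for run_pack(); any CUDA kernel sequence works.
    return x @ x


x = torch.randn(1024, 1024, device="cuda")

# Warm up once, then capture the workload into a CUDA graph.
work(x)
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = work(x)

warmup_times = 5
run_times = 20
events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)]
for i in range(warmup_times + run_times):
    events[i].record()
    g.replay()
events[-1].record()
torch.cuda.synchronize()

# Event.elapsed_time() returns milliseconds; multiply by 1000 for microseconds.
time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
time_list = time_list[warmup_times:]
print(
    f"min {np.min(time_list) * 1000:.1f}"
    f" mean {np.mean(time_list) * 1000:.1f}"
    f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
    f" (us)"
)

Recording device-side events rather than host timers keeps the measurement on the GPU timeline and excludes Python launch overhead, which is why the script brackets each replay with events instead of calling time.time().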