
Commit 450bf90

Revert "[API change] deprecate tile_token_dim in trtllm_moe (#2086)"
This reverts commit 9a79b78.
Parent: 049e8db

10 files changed (+91, -8 lines)


benchmarks/README.md

Lines changed: 2 additions & 1 deletion

@@ -166,7 +166,8 @@ The output CSV will contain detailed metrics including:
 | `--topk_group` | Number of groups to consider for top-k routing. Default: 1 |
 | `--routed_scaling_factor`| Scaling factor for routing. Default: 2.5 |
 | `--local_expert_offset` | Offset of local experts in global expert space. Default: 0 |
-| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts | |
+| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts |
+| `--tile_tokens_dim` | Tile dimension for tokens. Default: 8 |
 | `--routing_method` | Routing method: `renormalize`, `deepseek_v3`, `llama4`, `renormalize_naive`. Default: `deepseek_v3`. |
 | `--use_shuffled_weight` | Whether to use shuffled weight layout |
 | `--weight_layout` | Weight layout: 0=MajorK, 1=MajorMn, 2=BlockMajorK. Default: 0 |
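With `--tile_tokens_dim` back in the flag table, a benchmark run can pin the tile size explicitly. A hypothetical invocation, shown only for illustration (the routine name and shape values are placeholders, not taken from this commit):

python3 flashinfer_benchmark.py --routine <trtllm MoE routine> --num_tokens 256 --hidden_size 4096 --intermediate_size 14336 --num_experts 128 --top_k 8 --routing_method deepseek_v3 --tile_tokens_dim 16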

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

Lines changed: 3 additions & 0 deletions

@@ -114,6 +114,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
         0, # local_expert_offset
         num_experts,
         2.5, # routed_scaling_factor
+        None, # tile_tokens_dim
         RoutingMethodType.DeepSeekV3.value,
         True, # use_shuffled_weight
         WeightLayout.BlockMajorK.value, # weight_layout
@@ -141,6 +142,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
         num_experts,
         1.0, # routed_scaling_factor
         False, # use_routing_scales_on_input
+        None, # tile_tokens_dim
         RoutingMethodType.TopK.value,
         enable_pdl,
         num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
@@ -285,6 +287,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
         0, # local_expert_offset
         num_experts,
         None, # routed_scaling_factor
+        None, # tile_tokens_dim
         RoutingMethodType.Renormalize.value,
         True,
         enable_pdl,

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@
     "routed_scaling_factor",
     "local_expert_offset",
     "local_num_experts",
+    "tile_tokens_dim",
     "routing_method",
     "use_shuffled_weight",
     "weight_layout",

benchmarks/routines/moe.py

Lines changed: 37 additions & 0 deletions

@@ -116,6 +116,13 @@ def parse_moe_args(line, parser):
         default=None,
         help="Number of experts handled by this device. Defaults to num_experts.",
     )
+    parser.add_argument(
+        "--tile_tokens_dim",
+        type=int,
+        required=False,
+        default=8,
+        help="Tile dimension for tokens.",
+    )
     parser.add_argument(
         "--routing_method",
         type=str,
@@ -553,6 +560,7 @@ def testTrtllmFp4BlockScaleMoe(args):
     )
     local_expert_offset = args.local_expert_offset
     local_num_experts = args.local_num_experts or num_experts
+    tile_tokens_dim = args.tile_tokens_dim
     routing_method_type = args.routing_method_type
     use_shuffled_weight = args.use_shuffled_weight
     weight_layout = args.weight_layout
@@ -697,6 +705,7 @@ def run_fp4_moe():
             local_expert_offset=local_expert_offset,
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
             gated_act_type=gated_act_type,
             do_finalize=True,
@@ -771,6 +780,7 @@ def run_fp4_moe():
         cur_res["routed_scaling_factor"] = routed_scaling_factor
         cur_res["local_expert_offset"] = local_expert_offset
         cur_res["local_num_experts"] = local_num_experts
+        cur_res["tile_tokens_dim"] = tile_tokens_dim
         cur_res["routing_method"] = args.routing_method
         cur_res["use_shuffled_weight"] = use_shuffled_weight
         cur_res["weight_layout"] = weight_layout
@@ -1175,6 +1185,7 @@ def testTrtllmFp8BlockScaleMoe(args):
     )
     local_expert_offset = args.local_expert_offset
     local_num_experts = args.local_num_experts or num_experts
+    tile_tokens_dim = args.tile_tokens_dim
     routing_method_type = args.routing_method_type
     use_shuffled_weight = args.use_shuffled_weight
     weight_layout = args.weight_layout
@@ -1266,6 +1277,27 @@ def testTrtllmFp8BlockScaleMoe(args):
         print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}")
         print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}")
 
+    # Match test heuristic for tile_tokens_dim when using BlockMajorK
+    if use_shuffled_weight and weight_layout == WeightLayout.BlockMajorK:
+
+        def _next_pow2(x: int) -> int:
+            x = max(1, x)
+            x -= 1
+            x |= x >> 1
+            x |= x >> 2
+            x |= x >> 4
+            x |= x >> 8
+            x |= x >> 16
+            return x + 1
+
+        tokens_per_expert = max(1, (num_tokens * top_k) // max(local_num_experts, 1))
+        suggested_tile = min(max(_next_pow2(tokens_per_expert), 8), 64)
+        if suggested_tile != tile_tokens_dim and args.verbose >= 1:
+            print(
+                f"[INFO] Overriding tile_tokens_dim {tile_tokens_dim} -> {suggested_tile} for BlockMajorK"
+            )
+        tile_tokens_dim = suggested_tile
+
     def run_fp8_block_moe():
         # Quantize hidden states to FP8 for block scale MOE
         hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn)
@@ -1288,6 +1320,7 @@ def run_fp8_block_moe():
             local_expert_offset=local_expert_offset,
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
             use_shuffled_weight=use_shuffled_weight,
             weight_layout=weight_layout,
@@ -1348,6 +1381,7 @@ def run_fp8_block_moe():
         cur_res["routed_scaling_factor"] = routed_scaling_factor
         cur_res["local_expert_offset"] = local_expert_offset
         cur_res["local_num_experts"] = local_num_experts
+        cur_res["tile_tokens_dim"] = tile_tokens_dim
         cur_res["routing_method"] = args.routing_method
         cur_res["use_shuffled_weight"] = use_shuffled_weight
         cur_res["weight_layout"] = weight_layout
@@ -1414,6 +1448,7 @@ def testTrtllmFp8PerTensorScaleMoe(args):
     )
     local_expert_offset = args.local_expert_offset
     local_num_experts = args.local_num_experts or num_experts
+    tile_tokens_dim = args.tile_tokens_dim
     routing_method_type = args.routing_method_type
     use_routing_scales_on_input = args.use_routing_scales_on_input
     is_cuda_graph_compatible = not args.no_cuda_graph
@@ -1492,6 +1527,7 @@ def run_fp8_per_tensor_moe():
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
             use_routing_scales_on_input=use_routing_scales_on_input,
+            tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
         )
 
@@ -1549,6 +1585,7 @@ def run_fp8_per_tensor_moe():
         cur_res["routed_scaling_factor"] = routed_scaling_factor
         cur_res["local_expert_offset"] = local_expert_offset
         cur_res["local_num_experts"] = local_num_experts
+        cur_res["tile_tokens_dim"] = tile_tokens_dim
         cur_res["routing_method"] = args.routing_method
         cur_res["use_routing_bias"] = args.use_routing_bias
         cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input

benchmarks/samples/sample_testlist_output.csv

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command
+routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,tile_tokens_dim,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command
 BatchPrefillWithPagedKVCacheWrapper,0.01244799979031086,0.0009464459008260536,13.963516944729905,0.3050282827732261,fa2,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
 BatchPrefillWithPagedKVCacheWrapper,0.01839040070772171,0.00021363710731210026,9.45155349045863,0.20646597430613514,cudnn,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
 BatchPrefillWithPagedKVCacheWrapper,0.008396799862384795,5.550615129103214e-05,20.70048814413847,0.45219512936224815,trtllm-gen,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
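In the sample CSV, the restored tile_tokens_dim column now sits between local_num_experts and routing_method. A minimal sketch for inspecting it, assuming pandas is installed and the file is read from the path shown in this diff:

import pandas as pd

df = pd.read_csv("benchmarks/samples/sample_testlist_output.csv")
# Attention routines leave the MoE-only columns empty, so these rows show NaN here.
print(df[["routine", "local_num_experts", "tile_tokens_dim", "routing_method"]].head())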

0 commit comments
