Commit 9a79b78

[API change] deprecate tile_tokens_dim in trtllm_moe (#2086)

## 📌 Description

Deprecate `tile_tokens_dim` in trtllm_moe. The parameter is already unused; it is now marked with a deprecation warning, with full removal planned for the next major release.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Refactor**
  * Removed the deprecated `tile_tokens_dim` parameter from MOE benchmarks and kernel functions, streamlining API calls and eliminating associated deprecation warnings.

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>

1 parent 54101e9 · commit 9a79b78
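For context, the accept-and-warn pattern the description refers to usually looks like the sketch below. This is a minimal illustration, not the actual FlashInfer implementation: the function name is a simplified stand-in and all other arguments are elided. The deprecated keyword is still accepted but has no effect, and passing it raises a `DeprecationWarning` until the next major release drops it.

```python
import warnings


def trtllm_moe(hidden_states, tile_tokens_dim=None, **kwargs):
    """Hypothetical stand-in for the trtllm_moe entry point (not the real API)."""
    if tile_tokens_dim is not None:
        # The argument is ignored; warn callers so they can stop passing it
        # before it is removed in the next major release.
        warnings.warn(
            "tile_tokens_dim is deprecated and has no effect; "
            "it will be removed in the next major release.",
            DeprecationWarning,
            stacklevel=2,
        )
    # ... dispatch to the MOE kernel with the remaining arguments ...
    return hidden_states
```

Callers migrating ahead of the removal simply stop passing the argument, which is exactly what the benchmark changes below do.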

File tree

10 files changed: +8, -91 lines

benchmarks/README.md

Lines changed: 1 addition & 2 deletions

```diff
@@ -166,8 +166,7 @@ The output CSV will contain detailed metrics including:
 | `--topk_group` | Number of groups to consider for top-k routing. Default: 1 |
 | `--routed_scaling_factor`| Scaling factor for routing. Default: 2.5 |
 | `--local_expert_offset` | Offset of local experts in global expert space. Default: 0 |
-| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts |
-| `--tile_tokens_dim` | Tile dimension for tokens. Default: 8 |
+| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts |
 | `--routing_method` | Routing method: `renormalize`, `deepseek_v3`, `llama4`, `renormalize_naive`. Default: `deepseek_v3`. |
 | `--use_shuffled_weight` | Whether to use shuffled weight layout |
 | `--weight_layout` | Weight layout: 0=MajorK, 1=MajorMn, 2=BlockMajorK. Default: 0 |
```

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -114,7 +114,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
         0,  # local_expert_offset
         num_experts,
         2.5,  # routed_scaling_factor
-        None,  # tile_tokens_dim
         RoutingMethodType.DeepSeekV3.value,
         True,  # use_shuffled_weight
         WeightLayout.BlockMajorK.value,  # weight_layout
@@ -142,7 +141,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
         num_experts,
         1.0,  # routed_scaling_factor
         False,  # use_routing_scales_on_input
-        None,  # tile_tokens_dim
         RoutingMethodType.TopK.value,
         enable_pdl,
         num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
@@ -287,7 +285,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4(
         0,  # local_expert_offset
         num_experts,
         None,  # routed_scaling_factor
-        None,  # tile_tokens_dim
         RoutingMethodType.Renormalize.value,
         True,
         enable_pdl,
```

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -53,7 +53,6 @@
     "routed_scaling_factor",
     "local_expert_offset",
     "local_num_experts",
-    "tile_tokens_dim",
     "routing_method",
     "use_shuffled_weight",
     "weight_layout",
```

benchmarks/routines/moe.py

Lines changed: 0 additions & 37 deletions

```diff
@@ -116,13 +116,6 @@ def parse_moe_args(line, parser):
         default=None,
         help="Number of experts handled by this device. Defaults to num_experts.",
     )
-    parser.add_argument(
-        "--tile_tokens_dim",
-        type=int,
-        required=False,
-        default=8,
-        help="Tile dimension for tokens.",
-    )
     parser.add_argument(
         "--routing_method",
         type=str,
@@ -560,7 +553,6 @@ def testTrtllmFp4BlockScaleMoe(args):
     )
     local_expert_offset = args.local_expert_offset
     local_num_experts = args.local_num_experts or num_experts
-    tile_tokens_dim = args.tile_tokens_dim
     routing_method_type = args.routing_method_type
     use_shuffled_weight = args.use_shuffled_weight
     weight_layout = args.weight_layout
@@ -705,7 +697,6 @@ def run_fp4_moe():
             local_expert_offset=local_expert_offset,
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
-            tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
             gated_act_type=gated_act_type,
             do_finalize=True,
@@ -780,7 +771,6 @@ def run_fp4_moe():
         cur_res["routed_scaling_factor"] = routed_scaling_factor
         cur_res["local_expert_offset"] = local_expert_offset
         cur_res["local_num_experts"] = local_num_experts
-        cur_res["tile_tokens_dim"] = tile_tokens_dim
         cur_res["routing_method"] = args.routing_method
         cur_res["use_shuffled_weight"] = use_shuffled_weight
         cur_res["weight_layout"] = weight_layout
@@ -1185,7 +1175,6 @@ def testTrtllmFp8BlockScaleMoe(args):
     )
     local_expert_offset = args.local_expert_offset
     local_num_experts = args.local_num_experts or num_experts
-    tile_tokens_dim = args.tile_tokens_dim
     routing_method_type = args.routing_method_type
     use_shuffled_weight = args.use_shuffled_weight
     weight_layout = args.weight_layout
@@ -1277,27 +1266,6 @@ def testTrtllmFp8BlockScaleMoe(args):
         print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}")
         print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}")

-    # Match test heuristic for tile_tokens_dim when using BlockMajorK
-    if use_shuffled_weight and weight_layout == WeightLayout.BlockMajorK:
-
-        def _next_pow2(x: int) -> int:
-            x = max(1, x)
-            x -= 1
-            x |= x >> 1
-            x |= x >> 2
-            x |= x >> 4
-            x |= x >> 8
-            x |= x >> 16
-            return x + 1
-
-        tokens_per_expert = max(1, (num_tokens * top_k) // max(local_num_experts, 1))
-        suggested_tile = min(max(_next_pow2(tokens_per_expert), 8), 64)
-        if suggested_tile != tile_tokens_dim and args.verbose >= 1:
-            print(
-                f"[INFO] Overriding tile_tokens_dim {tile_tokens_dim} -> {suggested_tile} for BlockMajorK"
-            )
-        tile_tokens_dim = suggested_tile
-
     def run_fp8_block_moe():
         # Quantize hidden states to FP8 for block scale MOE
         hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn)
@@ -1320,7 +1288,6 @@ def run_fp8_block_moe():
             local_expert_offset=local_expert_offset,
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
-            tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
             use_shuffled_weight=use_shuffled_weight,
             weight_layout=weight_layout,
@@ -1381,7 +1348,6 @@ def run_fp8_block_moe():
         cur_res["routed_scaling_factor"] = routed_scaling_factor
         cur_res["local_expert_offset"] = local_expert_offset
         cur_res["local_num_experts"] = local_num_experts
-        cur_res["tile_tokens_dim"] = tile_tokens_dim
         cur_res["routing_method"] = args.routing_method
         cur_res["use_shuffled_weight"] = use_shuffled_weight
         cur_res["weight_layout"] = weight_layout
@@ -1448,7 +1414,6 @@ def testTrtllmFp8PerTensorScaleMoe(args):
     )
     local_expert_offset = args.local_expert_offset
     local_num_experts = args.local_num_experts or num_experts
-    tile_tokens_dim = args.tile_tokens_dim
     routing_method_type = args.routing_method_type
     use_routing_scales_on_input = args.use_routing_scales_on_input
     is_cuda_graph_compatible = not args.no_cuda_graph
@@ -1527,7 +1492,6 @@ def run_fp8_per_tensor_moe():
             local_num_experts=local_num_experts,
             routed_scaling_factor=routed_scaling_factor,
             use_routing_scales_on_input=use_routing_scales_on_input,
-            tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
         )

@@ -1585,7 +1549,6 @@ def run_fp8_per_tensor_moe():
         cur_res["routed_scaling_factor"] = routed_scaling_factor
         cur_res["local_expert_offset"] = local_expert_offset
         cur_res["local_num_experts"] = local_num_experts
-        cur_res["tile_tokens_dim"] = tile_tokens_dim
         cur_res["routing_method"] = args.routing_method
         cur_res["use_routing_bias"] = args.use_routing_bias
         cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input
```

benchmarks/samples/sample_testlist_output.csv

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,tile_tokens_dim,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command
+routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command
 BatchPrefillWithPagedKVCacheWrapper,0.01244799979031086,0.0009464459008260536,13.963516944729905,0.3050282827732261,fa2,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
 BatchPrefillWithPagedKVCacheWrapper,0.01839040070772171,0.00021363710731210026,9.45155349045863,0.20646597430613514,cudnn,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
 BatchPrefillWithPagedKVCacheWrapper,0.008396799862384795,5.550615129103214e-05,20.70048814413847,0.45219512936224815,trtllm-gen,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
```
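One knock-on effect of the header change: consumers that read this CSV by column position shift by one, while name-based access is unaffected. A quick sanity check, assuming pandas is available (pandas is not required by the commit itself):

```python
import pandas as pd

# Load the sample output shipped with the benchmarks.
df = pd.read_csv("benchmarks/samples/sample_testlist_output.csv")

assert "tile_tokens_dim" not in df.columns  # column removed by this commit
print(df[["routine", "median_time", "backend"]].head())
```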
