@@ -409,7 +409,7 @@ void FusedMoeLauncher::init_common(
 
 class Bf16MoeLauncher : public FusedMoeLauncher {
  public:
-  static constexpr std::array<int32_t, 4> mSupportedTileNums = {8, 16, 32, 64};
+  static constexpr std::array<int32_t, 5> mSupportedTileNums = {8, 16, 32, 64, 128};
 
   Bf16MoeLauncher(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
                   TensorView const& hidden_states, TensorView const& gemm1_weights,
@@ -559,21 +559,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {
                                    use_shuffled_weight, weight_layout, gated_act_type);
   }
 
-  void check_routing() const override {
-    FusedMoeLauncher::check_routing_common();
-
-    if (use_routing_scales_on_input) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
-          << "routing_logits must be bfloat16.";
-    } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-               RoutingMethodType::DeepSeekV3) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
-          << "routing_logits must be float.";
-    } else {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
-          << "routing_logits must be bfloat16.";
-    }
-  }
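+  // The routing_logits dtype checks have moved to the trtllm_fp8_per_tensor_scale_moe entry point.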
+  void check_routing() const override { FusedMoeLauncher::check_routing_common(); }
 
   void prepare_routing() override {
     FusedMoeLauncher::prepare_routing_common();
@@ -767,14 +753,6 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
   void check_routing() const override {
     FusedMoeLauncher::check_routing_common();
 
-    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
-          << "routing_logits must be float.";
-    } else {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
-          << "routing_logits must be bfloat16.";
-    }
-
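+    // The routing_logits dtype checks have moved to the trtllm_fp8_block_scale_moe entry point.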
     if (args->n_group != 0) {
       TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
                      RoutingMethodType::DeepSeekV3)
@@ -1272,44 +1250,72 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
 Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
                        TensorView const& hidden_states, TensorView const& gemm1_weights,
                        TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
-                       int64_t n_group, int64_t topk_group, int64_t intermediate_size,
-                       int64_t local_expert_offset, int64_t local_num_experts,
-                       int64_t tile_tokens_dim, int64_t routing_method_type,
-                       bool use_shuffled_weight, int64_t weight_layout, int64_t moe_tactic,
-                       bool enable_pdl) {
+                       Optional<int64_t> n_group, Optional<int64_t> topk_group,
+                       int64_t intermediate_size, int64_t local_expert_offset,
+                       int64_t local_num_experts, int64_t routing_method_type,
+                       bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
+                       Array<int64_t> moe_tactic) {
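+  // moe_tactic packs [tile_N, config]; -1 entries fall back to default selection below.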
   // Just some basic type validation first and leave more checks to the launcher
   TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
       << "BF16 MoE: routing_logits must be bfloat16 or float.";
-  if (routing_bias.has_value()) {
-    TVM_FFI_ICHECK_EQ(routing_bias.value().dtype(), dl_bfloat16)
-        << "BF16 MoE: routing_bias must be bfloat16.";
-  }
   TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_bfloat16)
       << "BF16 MoE: hidden_states must be bfloat16.";
   TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_bfloat16)
       << "BF16 MoE: gemm1_weights must be bfloat16.";
   TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_bfloat16)
       << "BF16 MoE: gemm2_weights must be bfloat16.";
 
-  // Save params to MoE arguments
-  auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
-  args->num_tokens = hidden_states.size(0);
-  args->num_experts = num_experts;
-  args->hidden_size = hidden_states.size(1);
-  args->hidden_size_output = args->hidden_size;
-  args->top_k = top_k;
-  args->n_group = n_group;
-  args->topk_group = topk_group;
-  args->local_expert_offset = local_expert_offset;
-  args->local_num_experts = local_num_experts;
-  args->intermediate_size = intermediate_size;
-
-  Bf16MoeLauncher launcher(routing_logits, routing_bias, hidden_states, gemm1_weights,
-                           gemm2_weights);
-  launcher.init(std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight,
-                weight_layout);
-  auto data = launcher.run(moe_tactic, enable_pdl)[0];
-  return data;
+  auto const num_tokens = hidden_states.size(0);
+  auto const hidden_size = hidden_states.size(1);
+
+  // Calculate supported tile sizes
+  std::vector<int32_t> mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(),
+                                       Bf16MoeLauncher::mSupportedTileNums.end());
+  std::set<int32_t> selected_tile_nums =
+      computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
+
+  // Create a map of launchers for each tile size
+  std::unordered_map<int32_t, std::unique_ptr<Bf16MoeLauncher>> launchers_map;
+
+  for (int32_t curr_tile_N : selected_tile_nums) {
+    // Create MoE arguments for this launcher
+    auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
+    args->num_tokens = num_tokens;
+    args->num_experts = num_experts;
+    args->hidden_size = hidden_size;
+    args->hidden_size_output = args->hidden_size;
+    args->top_k = top_k;
+    args->n_group = n_group.value_or(0);
+    args->topk_group = topk_group.value_or(0);
+    args->local_expert_offset = local_expert_offset;
+    args->local_num_experts = local_num_experts;
+    args->intermediate_size = intermediate_size;
+
+    // Create and initialize launcher for this tile size
+    auto launcher = std::make_unique<Bf16MoeLauncher>(routing_logits, routing_bias, hidden_states,
+                                                      gemm1_weights, gemm2_weights);
+    launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
+                   weight_layout);
+
+    launchers_map[curr_tile_N] = std::move(launcher);
+  }
+
+  // Extract tile_N and config from moe_tactic
+  int64_t tile_N = moe_tactic[0];
+  int64_t config = moe_tactic[1];
+
+  // Handle default case
+  if (tile_N == -1 || config == -1) {
+    tile_N = *selected_tile_nums.begin();
+  }
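+  // Note: config may still be -1 at this point; it is forwarded to run() unchanged.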
+
+  // Get the launcher for the selected tile_N
+  auto& selected_launcher = launchers_map.at(tile_N);
+
+  // Run the launcher - it will create its own runner internally
+  auto result = selected_launcher->run(config, enable_pdl)[0];
+  return result;
 }
 
 Tensor trtllm_fp8_per_tensor_scale_moe(
@@ -1323,6 +1329,13 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
     Array<int64_t> config_index) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
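+  // Expected routing_logits dtype depends on the routing mode: bf16 when routing scales are
+  // applied on the input, fp32 for DeepSeekV3 routing, and bf16 otherwise.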
+  if (use_routing_scales_on_input) {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16)
       << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
   TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn)
@@ -1407,6 +1420,11 @@ Tensor trtllm_fp8_block_scale_moe(
     int64_t weight_layout, bool enable_pdl, Array<int64_t> config_index) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
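+  // DeepSeekV3 routing expects fp32 logits; all other routing methods expect bf16 logits.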
+  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn)
       << "FP8 block scale MoE: hidden_states must be fp16, bf16, or fp8.";
   TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
@@ -1507,6 +1525,24 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
       << "unsupported weight_scale_vec_size.";
   auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1;
 
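+  // Routing tensors are optional here; validate dtype and shape only when they are provided.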
+  if (routing_logits.has_value()) {
+    TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
+                   routing_logits.value().dtype() == dl_bfloat16)
+        << "routing_logits must be float or bfloat16.";
+    TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
+    TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts)
+        << "routing_logits has incorrect shape.";
+  }
+  if (routing_bias.has_value()) {
+    TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 ||
+                   routing_bias.value().dtype() == dl_float32)
+        << "routing_bias must be bfloat16 or float.";
+
+    TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D.";
+    TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts)
+        << "routing_bias has incorrect shape.";
+  }
+
   // Determine activation type
   TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8)
       << "weights must be fp4 packed in uint8.";