
Commit 157ba33

add BF16 autotune and fix api
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent 81c21c8 commit 157ba33

File tree: 4 files changed (+251, -114 lines)


csrc/trtllm_batched_gemm_runner.cu

Lines changed: 10 additions & 8 deletions
@@ -116,14 +116,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
     }
   }
 
-  FLASHINFER_CHECK(
-      !mPassingConfigIndices.empty(),
-      "No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, "
-      "mUseDeepSeekFp8: %d, "
-      "mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d",
-      tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(),
-      tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput,
-      mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize);
+  std::ostringstream error_msg;
+  error_msg << "No kernel found for the given options: "
+            << "mDtypeA: " << tg::dtypeToString(mOptions.dtypeA)
+            << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB)
+            << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC)
+            << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
+            << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
+            << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
+            << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
+  FLASHINFER_CHECK(!mPassingConfigIndices.empty(), error_msg.str());
 }
 
 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
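
For context, a minimal standalone sketch of the stream-based message construction the new code uses. The CHECK_OR_THROW macro and the option values are hypothetical stand-ins for FLASHINFER_CHECK and mOptions; the point is that std::ostringstream composes the diagnostic from std::string and integral values directly, instead of the printf-style format string it replaces.

```cpp
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for FLASHINFER_CHECK: throws with the composed message on failure.
#define CHECK_OR_THROW(cond, msg)               \
  do {                                          \
    if (!(cond)) throw std::runtime_error(msg); \
  } while (0)

int main() {
  // Made-up option values; the real runner reads these from mOptions.
  std::string dtypeA = "Bfloat16", dtypeB = "Bfloat16", dtypeC = "Bfloat16";
  bool deepSeekFp8 = false, transposeMmaOutput = true;
  int tileSize = 8;
  bool foundKernel = false;  // plays the role of !mPassingConfigIndices.empty()

  std::ostringstream error_msg;
  error_msg << "No kernel found for the given options: "
            << "mDtypeA: " << dtypeA << ", mDtypeB: " << dtypeB << ", mDtypeC: " << dtypeC
            << ", mUseDeepSeekFp8: " << deepSeekFp8
            << ", mTransposeMmaOutput: " << transposeMmaOutput << ", mTileSize: " << tileSize;

  try {
    CHECK_OR_THROW(foundKernel, error_msg.str());
  } catch (std::exception const& e) {
    std::cerr << e.what() << "\n";  // prints the fully formatted option dump
  }
  return 0;
}
```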

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 88 additions & 52 deletions
@@ -409,7 +409,7 @@ void FusedMoeLauncher::init_common(
 
 class Bf16MoeLauncher : public FusedMoeLauncher {
  public:
-  static constexpr std::array<int32_t, 4> mSupportedTileNums = {8, 16, 32, 64};
+  static constexpr std::array<int32_t, 5> mSupportedTileNums = {8, 16, 32, 64, 128};
 
   Bf16MoeLauncher(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
                   TensorView const& hidden_states, TensorView const& gemm1_weights,
@@ -559,21 +559,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {
                              use_shuffled_weight, weight_layout, gated_act_type);
   }
 
-  void check_routing() const override {
-    FusedMoeLauncher::check_routing_common();
-
-    if (use_routing_scales_on_input) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
-          << "routing_logits must be bfloat16.";
-    } else if (static_cast<RoutingMethodType>(routing_method_type) ==
-               RoutingMethodType::DeepSeekV3) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
-          << "routing_logits must be float.";
-    } else {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
-          << "routing_logits must be bfloat16.";
-    }
-  }
+  void check_routing() const override { FusedMoeLauncher::check_routing_common(); }
 
   void prepare_routing() override {
     FusedMoeLauncher::prepare_routing_common();
@@ -767,14 +753,6 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
   void check_routing() const override {
     FusedMoeLauncher::check_routing_common();
 
-    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
-          << "routing_logits must be float.";
-    } else {
-      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
-          << "routing_logits must be bfloat16.";
-    }
-
     if (args->n_group != 0) {
       TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
                      RoutingMethodType::DeepSeekV3)
@@ -1272,44 +1250,72 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
 Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
                        TensorView const& hidden_states, TensorView const& gemm1_weights,
                        TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
-                       int64_t n_group, int64_t topk_group, int64_t intermediate_size,
-                       int64_t local_expert_offset, int64_t local_num_experts,
-                       int64_t tile_tokens_dim, int64_t routing_method_type,
-                       bool use_shuffled_weight, int64_t weight_layout, int64_t moe_tactic,
-                       bool enable_pdl) {
+                       Optional<int64_t> n_group, Optional<int64_t> topk_group,
+                       int64_t intermediate_size, int64_t local_expert_offset,
+                       int64_t local_num_experts, int64_t routing_method_type,
+                       bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
+                       Array<int64_t> moe_tactic) {
   // Just some basic type validation first and leave more checks to the launcher
   TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
       << "BF16 MoE: routing_logits must be bfloat16 or float.";
-  if (routing_bias.has_value()) {
-    TVM_FFI_ICHECK_EQ(routing_bias.value().dtype(), dl_bfloat16)
-        << "BF16 MoE: routing_bias must be bfloat16.";
-  }
   TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_bfloat16)
       << "BF16 MoE: hidden_states must be bfloat16.";
   TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_bfloat16)
       << "BF16 MoE: gemm1_weights must be bfloat16.";
   TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_bfloat16)
       << "BF16 MoE: gemm2_weights must be bfloat16.";
 
-  // Save params to MoE arguments
-  auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
-  args->num_tokens = hidden_states.size(0);
-  args->num_experts = num_experts;
-  args->hidden_size = hidden_states.size(1);
-  args->hidden_size_output = args->hidden_size;
-  args->top_k = top_k;
-  args->n_group = n_group;
-  args->topk_group = topk_group;
-  args->local_expert_offset = local_expert_offset;
-  args->local_num_experts = local_num_experts;
-  args->intermediate_size = intermediate_size;
-
-  Bf16MoeLauncher launcher(routing_logits, routing_bias, hidden_states, gemm1_weights,
-                           gemm2_weights);
-  launcher.init(std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight,
-                weight_layout);
-  auto data = launcher.run(moe_tactic, enable_pdl)[0];
-  return data;
+  auto const num_tokens = hidden_states.size(0);
+  auto const hidden_size = hidden_states.size(1);
+
+  // Calculate supported tile sizes
+  std::vector<int32_t> mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(),
+                                       Bf16MoeLauncher::mSupportedTileNums.end());
+  std::set<int32_t> selected_tile_nums =
+      computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
+
+  // Create a map of launchers for each tile size
+  std::unordered_map<int32_t, std::unique_ptr<Bf16MoeLauncher>> launchers_map;
+
+  for (int32_t curr_tile_N : selected_tile_nums) {
+    // Create MoE arguments for this launcher
+    auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
+    args->num_tokens = num_tokens;
+    args->num_experts = num_experts;
+    args->hidden_size = hidden_size;
+    args->hidden_size_output = args->hidden_size;
+    args->top_k = top_k;
+    args->n_group = n_group.value_or(0);
+    args->topk_group = topk_group.value_or(0);
+    args->local_expert_offset = local_expert_offset;
+    args->local_num_experts = local_num_experts;
+    args->intermediate_size = intermediate_size;
+
+    // Create and initialize launcher for this tile size
+    auto launcher = std::make_unique<Bf16MoeLauncher>(routing_logits, routing_bias, hidden_states,
+                                                      gemm1_weights, gemm2_weights);
+    launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
+                   weight_layout);
+
+    launchers_map[curr_tile_N] = std::move(launcher);
+  }
+
+  // Extract tile_N and config from moe_tactic
+  int64_t tile_N = moe_tactic[0];
+  int64_t config = moe_tactic[1];
+
+  // Handle default case
+  if (tile_N == -1 || config == -1) {
+    tile_N = *selected_tile_nums.begin();
+  }
+
+  // Get the launcher for the selected tile_N
+  auto& selected_launcher = launchers_map.at(tile_N);
+
+  // Run the launcher - it will create its own runner internally
+  auto result = selected_launcher->run(config, enable_pdl)[0];
+  return result;
 }
 
 Tensor trtllm_fp8_per_tensor_scale_moe(
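
To illustrate the reworked API, here is a minimal sketch of the tactic-selection flow the new trtllm_bf16_moe follows: build one launcher per candidate tile size, read moe_tactic as a [tile_N, config] pair, and fall back to the smallest candidate when either entry is -1. DummyLauncher and selectTileN are hypothetical stand-ins; the real Bf16MoeLauncher and computeSelectedTileN are not shown in this diff, so the selection heuristic below is only an assumption.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>

// Placeholder for a per-tile-size launcher; the real Bf16MoeLauncher holds MoE arguments
// and kernel state, which are omitted here.
struct DummyLauncher {
  int32_t tile_n;
  void run(int64_t config) const {
    std::cout << "run with tile_N=" << tile_n << ", config=" << config << "\n";
  }
};

// Stand-in for computeSelectedTileN: a guessed heuristic that keeps tile sizes near the
// average number of tokens routed to each local expert (the real logic is not in this commit).
std::set<int32_t> selectTileN(std::vector<int32_t> const& supported, int64_t num_tokens,
                              int64_t top_k, int64_t local_num_experts) {
  double tokens_per_expert =
      static_cast<double>(num_tokens) * top_k / std::max<int64_t>(local_num_experts, 1);
  std::set<int32_t> selected;
  for (int32_t t : supported) {
    if (t <= 2 * tokens_per_expert || t == supported.front()) selected.insert(t);
  }
  return selected;
}

int main() {
  std::vector<int32_t> supported = {8, 16, 32, 64, 128};  // mirrors mSupportedTileNums
  auto selected = selectTileN(supported, /*num_tokens=*/64, /*top_k=*/4, /*local_num_experts=*/32);

  // One launcher per candidate tile size, keyed by tile_N as in the new trtllm_bf16_moe.
  std::unordered_map<int32_t, std::unique_ptr<DummyLauncher>> launchers;
  for (int32_t t : selected) launchers[t] = std::make_unique<DummyLauncher>(DummyLauncher{t});

  // moe_tactic = [tile_N, config]; {-1, -1} means "use the default tactic".
  std::vector<int64_t> moe_tactic = {-1, -1};
  int64_t tile_N = moe_tactic[0], config = moe_tactic[1];
  if (tile_N == -1 || config == -1) tile_N = *selected.begin();  // smallest candidate

  launchers.at(tile_N)->run(config);  // config stays -1 so the runner picks its own default
  return 0;
}
```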
@@ -1323,6 +1329,13 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
     Array<int64_t> config_index) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
+  if (use_routing_scales_on_input) {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16)
       << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
   TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn)
@@ -1407,6 +1420,11 @@ Tensor trtllm_fp8_block_scale_moe(
     int64_t weight_layout, bool enable_pdl, Array<int64_t> config_index) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
+  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn)
       << "FP8 block scale MoE: hidden_states must be fp16, bf16, or fp8.";
   TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
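
The routing_logits dtype checks removed from the launchers above now live in the FFI entry points. A compact sketch of that rule, with a stand-in enum (the enumerators and values are placeholders, not the library's RoutingMethodType): DeepSeekV3 routing expects float32 logits, routing scales applied on the input expect bfloat16 (per-tensor path only), and everything else expects bfloat16.

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Stand-in routing-method enum; only DeepSeekV3 matters for the dtype rule sketched here.
enum class RoutingMethod : int64_t { Default = 0, DeepSeekV3 = 2 };

// The dtype rule the FP8 entry points now enforce directly:
//  * routing scales applied on the input -> logits must be bfloat16
//  * DeepSeekV3 routing                  -> logits must be float32
//  * everything else                     -> logits must be bfloat16
std::string expectedRoutingLogitsDtype(RoutingMethod method, bool use_routing_scales_on_input) {
  if (use_routing_scales_on_input) return "bfloat16";
  if (method == RoutingMethod::DeepSeekV3) return "float32";
  return "bfloat16";
}

int main() {
  std::cout << expectedRoutingLogitsDtype(RoutingMethod::DeepSeekV3, false) << "\n";  // float32
  std::cout << expectedRoutingLogitsDtype(RoutingMethod::Default, true) << "\n";      // bfloat16
  return 0;
}
```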
@@ -1507,6 +1525,24 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
       << "unsupported weight_scale_vec_size.";
   auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1;
 
+  if (routing_logits.has_value()) {
+    TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
+                   routing_logits.value().dtype() == dl_bfloat16)
+        << "routing_logits must be float or bfloat16.";
+    TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
+    TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts)
+        << "routing_logits has incorrect shape.";
+  }
+  if (routing_bias.has_value()) {
+    TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 ||
+                   routing_bias.value().dtype() == dl_float32)
+        << "routing_bias must be bfloat16 or float.";
+
+    TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D.";
+    TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts)
+        << "routing_bias has incorrect shape.";
+  }
+
   // Determine activation type
   TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8)
       << "weights must be fp4 packed in uint8.";
