@@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
         // Free the profile workspace to save memory
@@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
 
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (mUseINT8WoqPerChannel)
         {
             // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
@@ -337,8 +340,16 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         else
         {
-            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-                "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+            if (isGatedActivation(base_activation_type))
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                    "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+            }
+            else
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
+                    "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
+            }
         }
 
         int experts_per_token = token_selected_experts.sizes()[1];
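For reference, a minimal standalone sketch of the shape check this hunk introduces. The enum values and the isGatedActivation() helper below are assumptions inferred from the diff, not the library's own definitions; the point is only the inter-size relationship being enforced.

// Sketch only: gated activations (e.g. SwiGLU) concatenate the gate and up
// projections in fc1, so its inter dimension is 2x fc2's; plain activations
// keep the two equal. ActivationType values here are assumed for illustration.
#include <cstdint>
#include <stdexcept>

enum class ActivationType { Relu, Gelu, Silu, Swiglu, Geglu };

inline bool isGatedActivation(ActivationType t)
{
    return t == ActivationType::Swiglu || t == ActivationType::Geglu;
}

inline void checkInterSizes(
    int64_t fc1_inter, int64_t fc2_inter, int64_t inner_dim_multiplier, ActivationType activation)
{
    int64_t const expected = fc2_inter * inner_dim_multiplier * (isGatedActivation(activation) ? 2 : 1);
    if (fc1_inter != expected)
    {
        throw std::invalid_argument("fc1_expert_weights inter size does not match fc2_expert_weights inter size");
    }
}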
@@ -375,7 +386,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int const num_experts_on_rank = fc2_expert_weights.sizes()[0];
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -474,8 +485,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
@@ -541,7 +552,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config
             = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -652,7 +665,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
         int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
         int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
-        int64_t const profile_id, bool const do_preparation, int64_t const unpadded_hidden_size)
+        int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
+        int64_t const unpadded_hidden_size)
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
@@ -661,6 +675,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         {
             return;
         }
+        ActivationType activation_type = static_cast<ActivationType>(activation_type_int);
 
         int64_t const num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
@@ -715,14 +730,14 @@ class FusedMoeRunner : public torch::CustomClassHolder
             tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
             tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
             hidden_size, unpadded_hidden_size > 0 ? unpadded_hidden_size : hidden_size, inter_size, group_size,
-            ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
+            activation_type, USE_BIAS, USE_LORA, min_latency_mode,
             /*need_weights*/ false, parallelism_config, enable_alltoall);
 #else
         mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
             tensorrt_llm::runtime::TorchUtils::dataType(activation_dtype),
             tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
             tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-            hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
+            hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, min_latency_mode,
             /*need_weights*/ false, parallelism_config);
 #endif
 
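The same default-to-Swiglu conversion appears in both the runner paths (optional argument) and the profiler path (plain integer). A hedged sketch of that pattern, with std::optional standing in for torch::optional so it compiles on its own; the enum values are again assumed:

// Sketch only: map an integer activation id from the caller onto the enum,
// falling back to Swiglu (the previously hard-coded value) when it is omitted.
#include <cstdint>
#include <optional>

enum class ActivationType { Relu, Gelu, Silu, Swiglu, Geglu };

inline ActivationType resolveActivationType(std::optional<int64_t> const& activation_type)
{
    return activation_type.has_value() ? static_cast<ActivationType>(activation_type.value())
                                       : ActivationType::Swiglu;
}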