
Commit 11132fe

nzmora-nvidia authored and Wanli-Jiang committed
Add relu2 to kernel and python api
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

Fixes and UT
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

Use trtllm moe for relu2 mlp case
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>

Fix the runGemmProfile
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>

Replace the FP8 fused MoE backend
Before: torch.ops.auto_deploy.triton_quant_fp8_moe
After: torch.ops.auto_deploy.trtllm_quant_fp8moe_fused
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

Code refactoring
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

syntax error fixes
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

remove dead code
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

fix moe operator function name
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

Add skips if not hopper+
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

remove unused code
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
1 parent 6e8037a commit 11132fe

File tree

9 files changed: +722 -31 lines changed

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h

Lines changed: 24 additions & 0 deletions
@@ -59,6 +59,30 @@ __forceinline__ __device__ float tanh_opt(float x)
 #endif
 }
 
+template <typename T>
+struct Relu2
+{
+    static bool const kIsHeavy = false;
+
+    CUTLASS_HOST_DEVICE
+    T operator()(T threshold, T value) const
+    {
+        ReLu<T> relu_op;
+        multiplies<T> mul;
+        T val = relu_op(threshold, value);
+        return mul(val, val);
+    }
+
+    CUTLASS_HOST_DEVICE
+    T operator()(T value) const
+    {
+        ReLu<T> relu_op;
+        multiplies<T> mul;
+        T val = relu_op(value);
+        return mul(val, val);
+    }
+};
+
 } // namespace thread
 } // namespace epilogue
 } // namespace cutlass
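
Note: Relu2 composes CUTLASS's existing ReLu and multiplies functors, so it computes squared ReLU, relu(x)^2; the two-argument overload simply forwards a threshold to ReLu in the same way the one-argument form does without one. A minimal host-side sketch of the intended math (a plain C++ reference, not the CUTLASS device functor above):

#include <algorithm>
#include <cassert>

// Reference semantics of Relu2: square the ReLU of the input.
template <typename T>
T relu2_ref(T x)
{
    T r = std::max(x, T(0)); // ReLU
    return r * r;            // then square
}

int main()
{
    assert(relu2_ref(3.0f) == 9.0f);  // positive inputs are squared
    assert(relu2_ref(-2.0f) == 0.0f); // negative inputs are zeroed before squaring
    return 0;
}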

cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ enum class ActivationType
     Swiglu,
     Geglu,
     SwigluBias,
+    Relu2,
     Identity,
     InvalidType
 };
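
Note: because Relu2 is inserted before Identity, the underlying integer values of Identity and InvalidType shift by one. moeOp.cpp (below) receives the activation as an int64 and casts it back with static_cast<ActivationType>, so any caller that passes the activation as a raw integer has to follow this updated order. A small sketch of that round trip, using a reduced stand-in enum (members before Swiglu are omitted, so absolute values differ from the real header):

#include <cassert>
#include <cstdint>

// Reduced stand-in for ActivationType; only the relative order matters here.
enum class ActivationType : int64_t { Swiglu, Geglu, SwigluBias, Relu2, Identity, InvalidType };

int main()
{
    // The op interface carries the activation as a plain int64...
    int64_t const wire_value = static_cast<int64_t>(ActivationType::Relu2);
    // ...and the C++ side casts it back, exactly as moeOp.cpp does.
    auto const activation = static_cast<ActivationType>(wire_value);
    assert(activation == ActivationType::Relu2);
    assert(static_cast<int64_t>(ActivationType::Identity) == wire_value + 1); // Identity now follows Relu2
    return 0;
}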

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 1 addition & 0 deletions
@@ -954,6 +954,7 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(
     case ActivationType::Identity: runGemm<cutlass_extensions::EpilogueOpDefault>(inputs, hopper_inputs); break;
     case ActivationType::Swiglu: runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(inputs, hopper_inputs); break;
     case ActivationType::Geglu: runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs); break;
+    case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break;
     case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
     default: TLLM_THROW("Invalid activation type."); break;
     }

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -2307,6 +2307,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
             decltype(block_scaling_type)::value>, // Geglu
         &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
             decltype(block_scaling_type)::value>, // SwigluBias
+        &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
+            decltype(block_scaling_type)::value>, // Relu2
         &doActivationKernel<T, GemmOutputType, ScaleBiasType,
             IdentityAdaptor<cutlass::epilogue::thread::Identity>,
             decltype(block_scaling_type)::value> // Identity
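
Note: the new entry sits between SwigluBias and Identity, mirroring the enum order in common.h. Assuming doActivation selects the kernel by indexing this table with the ActivationType value, as the per-entry comments suggest, that position is what binds Relu2 to its kernel. A standalone sketch of the pattern, with placeholder functions instead of the real doActivationKernel instantiations and only the enum members visible in this diff:

#include <array>
#include <cstdio>

enum class ActivationType { Swiglu, Geglu, SwigluBias, Relu2, Identity, InvalidType };

// Placeholder "kernels" standing in for the doActivationKernel instantiations.
void swigluKernel()     { std::puts("swiglu"); }
void gegluKernel()      { std::puts("geglu"); }
void swigluBiasKernel() { std::puts("swiglu_bias"); }
void relu2Kernel()      { std::puts("relu2"); }
void identityKernel()   { std::puts("identity"); }

int main()
{
    // Table order must match the enum order for the index-based lookup to work.
    std::array<void (*)(), 5> const fn_table{
        &swigluKernel, &gegluKernel, &swigluBiasKernel, &relu2Kernel, &identityKernel};

    auto const activation = ActivationType::Relu2;
    fn_table[static_cast<int>(activation)](); // prints "relu2"
    return 0;
}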

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 26 additions & 11 deletions
@@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
         // Free the profile workspace to save memory
@@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
 
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (mUseINT8WoqPerChannel)
         {
             // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
@@ -337,8 +340,16 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         else
         {
-            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-                "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+            if (isGatedActivation(base_activation_type))
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                    "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+            }
+            else
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
+                    "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
+            }
         }
 
         int experts_per_token = token_selected_experts.sizes()[1];
@@ -375,7 +386,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int const num_experts_on_rank = fc2_expert_weights.sizes()[0];
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -474,8 +485,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
@@ -541,7 +552,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config
             = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -652,7 +665,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
         int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
         int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
-        int64_t const profile_id, bool const do_preparation, int64_t const unpadded_hidden_size)
+        int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
+        int64_t const unpadded_hidden_size)
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
@@ -661,6 +675,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         {
             return;
         }
+        ActivationType activation_type = static_cast<ActivationType>(activation_type_int);
 
         int64_t const num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
@@ -715,14 +730,14 @@ class FusedMoeRunner : public torch::CustomClassHolder
                 tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
                 tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
                 hidden_size, unpadded_hidden_size > 0 ? unpadded_hidden_size : hidden_size, inter_size, group_size,
-                ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
+                activation_type, USE_BIAS, USE_LORA, min_latency_mode,
                 /*need_weights*/ false, parallelism_config, enable_alltoall);
 #else
             mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
                 tensorrt_llm::runtime::TorchUtils::dataType(activation_dtype),
                 tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
                 tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-                hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
+                hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, min_latency_mode,
                 /*need_weights*/ false, parallelism_config);
 #endif
 
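
Note: the relaxed shape check reflects the difference between gated and non-gated MLPs. Gated activations (Swiglu/Geglu) keep the gate and up projections fused in fc1, so fc1's inter dimension is twice fc2's; a plain activation such as Relu2 needs only one projection, so the dimensions match. A small sketch of the relationship the two TORCH_CHECKs enforce (the mInnerDimMultiplier factor for packed weight layouts is omitted, and is_gated stands in for isGatedActivation, whose member list is not shown in this diff):

#include <cassert>
#include <cstdint>

// Expected fc1 inter dimension given fc2's inter dimension.
int64_t expected_fc1_inter_size(int64_t fc2_inter_size, bool is_gated)
{
    // Gated MLPs pack gate + up projections into fc1, so it is twice as wide.
    return is_gated ? 2 * fc2_inter_size : fc2_inter_size;
}

int main()
{
    int64_t const inter_size = 4096;
    assert(expected_fc1_inter_size(inter_size, /*is_gated=*/true) == 8192);  // Swiglu-style MLP
    assert(expected_fc1_inter_size(inter_size, /*is_gated=*/false) == 4096); // Relu2 MLP
    return 0;
}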
