 #include <numeric>
 #include <random>
 #include <sstream>
+#include <type_traits>
 
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/workspace.h"
@@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken(
 // ============================== Infer GEMM sizes =================================
 // TODO Could linear search be better for small # experts
 template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
     int64_t const arr_length, T const target) {
   int64_t low = 0, high = arr_length - 1, target_location = -1;
   while (low <= high) {
@@ -881,6 +882,48 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
   return target_location + 1;
 }
 
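+// Warp-parallel replacement for the binary search above: each of the 32 lanes
+// counts the elements it owns (strided by warp size) that are < target, then a
+// shuffle reduction sums the per-lane counts. Assumes a fully converged warp
+// (full_mask) and an array length that is a multiple of 32; the runtime length
+// must equal ARR_LENGTH_CONST or the kernel traps.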
+template <int ARR_LENGTH_CONST, class T>
+__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  if (arr_length != ARR_LENGTH_CONST) {
+    asm("trap;");
+  }
+
+  constexpr unsigned full_mask = 0xffffffffu;
+  constexpr int WARP_SZ = 32;
+  const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+  int local_count = 0;
+#pragma unroll
+  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+    const int idx = lane_id + k * WARP_SZ;
+    T v = sorted_indices[idx];
+    local_count += (v < target) ? 1 : 0;
+  }
+
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    local_count += __shfl_down_sync(full_mask, local_count, offset);
+  }
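+  // After the shuffle tree, lane 0 holds the warp-wide total; broadcast it so
+  // every lane returns the same count.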
+  int total = __shfl_sync(full_mask, local_count, 0);
+
+  return (int64_t)total;
+}
+
+template <int ARR_LENGTH_CONST, class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+  // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+
+  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+  // int64_t out_v2 = findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+  // if (out_v1 != out_v2) {
+  //   printf("different output! v1=%lld v2=%lld\n", out_v1, out_v2);
+  //   asm("trap;");
+  // }
+  // return out_v1;
+}
+
 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<
     typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
@@ -1418,16 +1461,19 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ>
+          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
     int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
     int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
     int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
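+  // hidden_size is frozen to a compile-time constant (7168) so the loop bounds
+  // below become constexpr and fully unrollable; trap if the runtime value
+  // disagrees.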
+  constexpr int hidden_size = 7168;
+  if (hidden_size != hidden_size_real_) { asm("trap;"); }
+
   static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
                     !PRE_QUANT_AWQ,
                 "AWQ and Block Scaling are mutually exclusive");
@@ -1503,14 +1549,14 @@ __global__ void expandInputRowsKernel(
       permuted_row * hidden_size / ELEM_PER_THREAD;
 
   int64_t const start_offset = threadIdx.x;
-  int64_t const stride = EXPAND_THREADS_PER_BLOCK;
-  int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+  constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
+  constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
   assert(hidden_size % ELEM_PER_THREAD == 0);
   assert(hidden_size % VecSize == 0);
 
   if constexpr (is_nvfp4 || is_mxfp8) {
     static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-    int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+    int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                          (int64_t)permuted_row + 1) -
                      1;
 
@@ -1519,6 +1565,7 @@ __global__ void expandInputRowsKernel(
     float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
     int64_t num_tokens_before_expert = expert_first_token_offset[expert];
 
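+    // stride and num_elems_in_col are now compile-time constants, so the copy /
+    // quantize loop below can be fully unrolled.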
+#pragma unroll
     for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
       auto in_vec = source_row_ptr[elem_index];
       if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1687,9 +1734,20 @@ void expandInputRowsKernelLauncher(
     TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
                          "NVFP4 block scaling is expected for FP4xFP4");
     TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
-    return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
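+    // Map the runtime expert count onto the kernel's compile-time template
+    // argument; only 128 and 64 experts per node are wired up.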
+    if (num_experts_per_node == 128) {
+      constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
                                   TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-                                  false>;
+                                  false, NUM_EXPERTS_PER_NODE_CONST>;
+    }
+    if (num_experts_per_node == 64) {
+      constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                                    TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                                    false, NUM_EXPERTS_PER_NODE_CONST>;
+    }
+    printf("unsupported num_experts_per_node\n");
+    exit(1);
   } else
 #endif
   {
@@ -1748,11 +1806,16 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
 template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
-__global__ void finalizeMoeRoutingKernel(
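+// __maxnreg__(64) caps register usage at 64 per thread, presumably trading a
+// few spills for higher occupancy on this bandwidth-bound reduction.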
+__global__
+__maxnreg__(64)
+void finalizeMoeRoutingKernel(
     GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
     ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
+    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
     int const num_experts_per_node, int const start_expert_id) {
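+  // experts_per_token (top-k) is frozen at 8; trap if the runtime value differs.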
+  constexpr int experts_per_token = 8;
+  if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
+
   int64_t const original_row = blockIdx.x;
   int64_t const num_rows = gridDim.x;
   auto const offset = original_row * orig_cols;
@@ -2078,7 +2141,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
   float gate_bias = 0.0f;
   float gate_limit = std::numeric_limits<float>::infinity();
   if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
-    int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
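+    // This call site hard-codes 128 experts per node rather than dispatching on
+    // the runtime count; the guard inside the search traps on a mismatch.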
+    int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                                              (int64_t)token + 1) -
                  1;
     gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2126,14 +2189,17 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-                                   int num_experts_per_node, int64_t inter_size,
+                                   int num_experts_per_node, int64_t inter_size_real_,
                                    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
                                    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
                                    ActivationParams activation_params) {
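+  // inter_size is frozen at 2048 so the activation loop bounds below are
+  // constexpr; trap if the runtime value differs.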
+  constexpr int inter_size = 2048;
+  if (inter_size != inter_size_real_) { asm("trap;"); }
+
 #ifdef ENABLE_FP4
   constexpr bool IsNVFP4 =
       std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2186,7 +2252,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                       activation_params.swiglu_limit) {
     // TODO this is almost certainly faster as a linear scan
     expert =
-        findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+        findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
         1;
     gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
     gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2218,16 +2284,18 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
   auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
   auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
   int64_t const start_offset = tid;
-  int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
+  constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
   assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-  int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+  constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
   assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
   int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
 
   ActFn fn{};
   fn.alpha = gate_alpha;
   fn.beta = gate_beta;
   fn.limit = gate_limit;
+
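+  // constexpr trip count and stride again allow full unrolling.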
+#pragma unroll
   for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
     auto fc1_value =
         arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
@@ -2358,30 +2426,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
23582426
23592427 auto fn = [&]() {
23602428 auto fn = [&](auto block_scaling_type) {
2361- auto fn_list = std::array{
2362- &doActivationKernel<T, GemmOutputType, ScaleBiasType,
2363- IdentityAdaptor<cutlass::epilogue::thread::GELU>,
2364- decltype (block_scaling_type)::value>, // Gelu
2365- &doActivationKernel<T, GemmOutputType, ScaleBiasType,
2366- IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
2367- decltype (block_scaling_type)::value>, // Relu
2368- &doActivationKernel<T, GemmOutputType, ScaleBiasType,
2369- IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
2370- decltype (block_scaling_type)::value>, // Silu
2371- &doActivationKernel<T, GemmOutputType, ScaleBiasType,
2372- GLUAdaptor<cutlass::epilogue::thread::SiLu>,
2373- decltype (block_scaling_type)::value>, // Swiglu
2374- &doActivationKernel<T, GemmOutputType, ScaleBiasType,
2375- GLUAdaptor<cutlass::epilogue::thread::GELU>,
2376- decltype (block_scaling_type)::value>, // Geglu
2377- &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
2378- decltype (block_scaling_type)::value>, // SwigluBias
2379- &doActivationKernel<T, GemmOutputType, ScaleBiasType,
2380- IdentityAdaptor<cutlass::epilogue::thread::Identity>,
2381- decltype (block_scaling_type)::value> // Identity
2382-
2383- };
2384- return fn_list[static_cast <int >(activation_type.activation_type )];
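+      // Same runtime-to-compile-time dispatch as in expandInputRowsKernelLauncher:
+      // one kernel table per supported expert count.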
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,