@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
 )


+@triton.jit
+def compute_identity_kernel(
+    top_k: int,
+    hidden_states_ptr: tl.tensor,
+    expert_scales_ptr: tl.tensor,
+    num_tokens: int,
+    output_ptr: tl.tensor,
+    hidden_dim: int,
+    scales_stride: int,
+    BLOCK_SIZE: tl.constexpr,
+) -> None:
+    # Each program handles one BLOCK_SIZE-wide slice of one token's hidden
+    # state; the grid assumes hidden_dim is a multiple of BLOCK_SIZE.
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
+                tl.arange(0, BLOCK_SIZE),
+                mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
+
+    # Scale the unchanged hidden state by the summed routing weights of the
+    # zero (identity) experts selected for this token.
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
+             tl.arange(0, BLOCK_SIZE),
+             result,
+             mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
+
+
+def zero_experts_compute_triton(expert_indices: torch.Tensor,
+                                expert_scales: torch.Tensor, num_experts: int,
+                                zero_expert_type: str,
+                                hidden_states: torch.Tensor) -> torch.Tensor:
+    top_k = expert_indices.size(-1)
+
+    if zero_expert_type != "identity":
+        raise ValueError(f"Unknown zero_expert_type: {zero_expert_type}")
+
+    # Indices >= num_experts refer to zero (identity) experts: keep only
+    # their routing weights for the identity pass below.
+    zero_expert_mask = expert_indices < num_experts
+    zero_expert_scales = expert_scales.clone()
+    zero_expert_scales[zero_expert_mask] = 0.0
+
+    # Mask the zero experts out (in place) so the regular fused-MoE path
+    # ignores those slots.
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
+
+
 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
 def get_config_file_name(E: int,
                          N: int,
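For context, a minimal sketch of how the new identity-expert path might be driven. The shapes, the device, and the routing tensors below are illustrative assumptions, not taken from this diff:

    import torch

    num_tokens, hidden_dim, num_experts, top_k = 4, 512, 8, 2
    hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda")
    # Routing results; indices >= num_experts denote zero (identity) experts.
    expert_indices = torch.randint(0, num_experts + 2, (num_tokens, top_k),
                                   device="cuda")
    expert_scales = torch.rand(num_tokens, top_k, device="cuda")

    identity_output = zero_experts_compute_triton(expert_indices,
                                                  expert_scales, num_experts,
                                                  "identity", hidden_states)
    # identity_output holds hidden_states scaled by the summed routing weights
    # of the zero experts; expert_indices/expert_scales were masked in place
    # so the regular fused-MoE kernels skip those slots.
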
@@ -940,6 +1010,25 @@ def fused_topk(
     return topk_weights, topk_ids, token_expert_indices


+def fused_topk_bias(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    e_score_correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    n_routed_experts = gating_output.shape[-1]
+    scores = gating_output.softmax(dim=-1)
+    # The correction bias influences only which experts are selected; the
+    # returned weights are gathered from the unbiased scores.
+    scores_for_choice = scores.view(
+        -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
+                              sorted=False)[1]
+    topk_weights = scores.gather(1, topk_indices)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
+
+
 # This is used by the Deepseek-V2 and Deepseek-V3 model
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(
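Similarly, a small sketch of calling fused_topk_bias; the shapes and values are assumptions for illustration only:

    import torch

    num_tokens, n_experts, topk = 4, 16, 2
    hidden_states = torch.randn(num_tokens, 128)
    gating_output = torch.randn(num_tokens, n_experts)
    e_score_correction_bias = torch.zeros(n_experts)

    weights, ids = fused_topk_bias(hidden_states, gating_output,
                                   e_score_correction_bias, topk,
                                   renormalize=True)
    # weights: float32 [num_tokens, topk], gathered from the unbiased softmax
    # scores; ids: int32 [num_tokens, topk], chosen from the biased scores.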