From dd9438615f08eba927e96df81685e5366d14dd00 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 21 Aug 2025 16:17:23 -0700 Subject: [PATCH 1/5] add scattermoe kernel --- .../kernels/scattermoe/__init__.py | 185 ++++++++++++++++++ .../scattermoe/group_backward_kernel.py | 125 ++++++++++++ .../kernels/scattermoe/group_kernel.py | 101 ++++++++++ .../kernels/scattermoe/scatter_kernel.py | 180 +++++++++++++++++ 4 files changed, 591 insertions(+) create mode 100644 src/transformers/kernels/scattermoe/__init__.py create mode 100644 src/transformers/kernels/scattermoe/group_backward_kernel.py create mode 100644 src/transformers/kernels/scattermoe/group_kernel.py create mode 100644 src/transformers/kernels/scattermoe/scatter_kernel.py diff --git a/src/transformers/kernels/scattermoe/__init__.py b/src/transformers/kernels/scattermoe/__init__.py new file mode 100644 index 000000000000..442f0e2400b4 --- /dev/null +++ b/src/transformers/kernels/scattermoe/__init__.py @@ -0,0 +1,185 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch + +from transformers.kernels.scattermoe.group_backward_kernel import group_bwd_W +from transformers.kernels.scattermoe.group_kernel import group +from transformers.kernels.scattermoe.scatter_kernel import scatter2scatter + + +class _ScatteredExperts(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + output = torch.empty(sorted_expert_idxs.size(0), expert_weights.size(-1), device=x.device, dtype=x.dtype) + + scatter2scatter( + X=x, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + out=output, + FAN_OUT=k, + x_grouped=grouped_in, + y_grouped=grouped_out, + ) + + if gates is None: + output_expanded = None + else: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1) + + ctx.save_for_backward( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded, + ) + + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + + return output + + @staticmethod + def backward(ctx, grad_out): + ( + x, + expert_weights, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded, + ) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + + if gates is None: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + else: + # calculate gates gradient + d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + # print("expanded and grouping") + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + + if grouped_out: + grouped_grad_out = grad_out + else: + if grouped_grad_out is None: + if gate_fan == 1: + grouped_grad_out = torch.empty_like(grad_out) + else: + raise RuntimeError("Need to infer size") + group( + A=grad_out, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_grad_out, + coeff=gates_flat, + fan_out=gate_fan, + ) + + if grouped_in: + grouped_x = x + d_expanded_input = torch.empty( + sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype + ) + else: + grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device) + group( + A=x, + sorted_expert_idxs=sorted_scattered_idxs, + out=grouped_x, + fan_out=k, + ) + + d_expanded_input = grouped_x + + d_weights = torch.zeros_like(expert_weights) + + group_bwd_W( + DY=grouped_grad_out, + X=grouped_x, + expert_offsets=expert_offsets, + DW=d_weights, + E=expert_weights.size(0), + ) + + scatter2scatter( + X=grouped_grad_out, + W=expert_weights.permute(0, 2, 1), + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + out=d_expanded_input, + FAN_OUT=1, + x_grouped=True, + y_grouped=grouped_in, + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2) + + return ( + # x, expert_weights, k, + d_input, + d_weights, + None, + # sorted_expert_idxs, sorted_scattered_idxs, + None, + None, + # expert_offsets, + None, + # gates + d_gates, + None, + None, + ) + + +def scattered_experts( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, +): + return _ScatteredExperts.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + grouped_in, + grouped_out, + ) diff --git a/src/transformers/kernels/scattermoe/group_backward_kernel.py b/src/transformers/kernels/scattermoe/group_backward_kernel.py new file mode 100644 index 000000000000..cb0cb6ab1e14 --- /dev/null +++ b/src/transformers/kernels/scattermoe/group_backward_kernel.py @@ -0,0 +1,125 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch +import triton +import triton.language as tl +from torch.library import custom_op + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 128}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 64}, num_stages=2, num_warps=4), + ], + key=["N", "K"], +) +@triton.jit +def groupXtY_triton_kernel( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + expert_offsets_ptr, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32) + + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + + no_k_mask = K % BLOCK_K == 0 + no_n_mask = N % BLOCK_N == 0 + + for i in range(iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + + if no_k_mask: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + + if no_n_mask: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + acc = tl.dot(xt, dy, acc, allow_tf32=True) + + DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + + +@custom_op("transformers::group_bwd_W", mutates_args={"DW"}) +def group_bwd_W(DY: torch.Tensor, X: torch.Tensor, expert_offsets: torch.Tensor, DW: torch.Tensor, E: int) -> None: + def grid(meta): + return (E * triton.cdiv(meta["K"], meta["BLOCK_K"]), triton.cdiv(meta["N"], meta["BLOCK_N"])) + + with torch.device(X.device): + groupXtY_triton_kernel[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + N=DY.size(-1), + K=X.size(-1), + ) diff --git a/src/transformers/kernels/scattermoe/group_kernel.py b/src/transformers/kernels/scattermoe/group_kernel.py new file mode 100644 index 000000000000..495cd72c3297 --- /dev/null +++ b/src/transformers/kernels/scattermoe/group_kernel.py @@ -0,0 +1,101 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch +import triton +import triton.language as tl +from torch.library import custom_op + + +@triton.autotune(configs=[triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4)], key=["K"]) +@triton.jit +def group_triton_kernel( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + N_block_id = tl.program_id(axis=0) + + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + no_k_mask = K % BLOCK_K == 0 + + for i in range(iters): + if no_k_mask or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + + if has_coeff: + block *= c + + tl.store(tgt_blk_ptrs, block, mask=mask) + + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti + + +@custom_op("transformers::group", mutates_args={"out"}) +def group( + A: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + out: torch.Tensor, + coeff: torch.Tensor | None = None, + fan_out: int = 1, +) -> None: + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + + def grid(meta): + return (triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + with torch.device(A.device): + group_triton_kernel[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + coeff is not None, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + out, + out.stride(0), + out.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) diff --git a/src/transformers/kernels/scattermoe/scatter_kernel.py b/src/transformers/kernels/scattermoe/scatter_kernel.py new file mode 100644 index 000000000000..c4ac1e1061c1 --- /dev/null +++ b/src/transformers/kernels/scattermoe/scatter_kernel.py @@ -0,0 +1,180 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch +import triton +import triton.language as tl +from torch.library import custom_op + + +@triton.jit +def _compute_expert_block( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + K, + acc, + no_k_mask, + BLOCK_K, +): + K_block = tl.arange(0, BLOCK_K) + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we + iters = tl.cdiv(K, BLOCK_K) + + for K_block_id in range(iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc = tl.dot(x, w, acc, allow_tf32=True) + + return acc + + +@triton.autotune( + configs=[triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4)], + key=["N", "K"], +) +@triton.jit +def scatter2scatter_triton_kernel( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + Y_ptr, + stride_ym, + stride_yn, + grouped_idx_ptr, + expert_idxs_ptr, + FAN_OUT, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + x_grouped, + y_grouped, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + + M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + M_boundary_mask = M_block < (FAN_OUT * M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E) + + no_k_mask = K % BLOCK_K == 0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + E_first_idx = tl.min(E_idxs) + E_last_idx = tl.minimum(tl.max(E_idxs), E - 1) + M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32) + + for E_idx in range(E_first_idx, E_last_idx + 1): + E_mask = E_idxs == E_idx + E_M_idx = M_idx + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = E_M_idx // FAN_OUT + + acc = _compute_expert_block( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + K, + acc, + no_k_mask, + BLOCK_K, + ) + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :]) + + +@custom_op("transformers::scatter2scatter", mutates_args={"out"}) +def scatter2scatter( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + out: torch.Tensor, + FAN_OUT: int, + x_grouped: bool = False, + y_grouped: bool = False, +) -> None: + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * FAN_OUT + assert out.size(0) == sorted_expert_idxs.size(0) + assert out.size(1) == W.size(-1) + + def grid(meta): + return (triton.cdiv(sorted_expert_idxs.size(0), meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) + + BLOCK_M = 128 + + with torch.device(X.device): + scatter2scatter_triton_kernel[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + out, + out.stride(0), + out.stride(1), + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + FAN_OUT=FAN_OUT, + M=X.size(0), + K=X.size(1), + N=out.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) From 013458c1caaacbd32680c030827296fa65d6fe99 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 21 Aug 2025 16:32:45 -0700 Subject: [PATCH 2/5] add scattermoe kernel to granitemoe --- .../models/granitemoe/modeling_granitemoe.py | 111 ++++++++++++++---- 1 file changed, 86 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 012a0581a808..27ec0e206986 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ...utils.deprecation import deprecate_kwarg +from ...kernels.scattermoe import scattered_experts from .configuration_granitemoe import GraniteMoeConfig @@ -317,6 +318,17 @@ def forward(self, hidden_states): return index_sorted_experts, batch_index, batch_gates, expert_size, logits +# TODO add support for combileable bincount in PyTorch directly +@torch.library.custom_op("transformers::bincount", mutates_args={}) +def bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength).to(torch.uint32) + + +@bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, device=x.device, dtype=torch.uint32) + + class GraniteMoeMoE(nn.Module): """ A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. @@ -341,36 +353,85 @@ def __init__(self, config: GraniteMoeConfig): top_k=config.num_experts_per_tok, ) - def forward(self, layer_input): - """ - Forward pass of the mixture of experts layer. - - Args: - layer_input (Tensor): - Input tensor. - - Returns: - Tensor: - Output tensor. - Tensor: - Router logits. - """ - bsz, length, emb_size = layer_input.size() - layer_input = layer_input.reshape(-1, emb_size) - _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # def forward(self, layer_input): + # """ + # Forward pass of the mixture of experts layer. + + # Args: + # layer_input (Tensor): + # Input tensor. + + # Returns: + # Tensor: + # Output tensor. + # Tensor: + # Router logits. + # """ + # bsz, length, emb_size = layer_input.size() + # layer_input = layer_input.reshape(-1, emb_size) + # _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + + # expert_inputs = layer_input[batch_index] + # hidden_states = self.input_linear(expert_inputs, expert_size) + # chunked_hidden_states = hidden_states.chunk(2, dim=-1) + # hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + # expert_outputs = self.output_linear(hidden_states, expert_size) + + # expert_outputs = expert_outputs * batch_gates[:, None] + + # zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) + # layer_output = zeros.index_add(0, batch_index, expert_outputs) + # layer_output = layer_output.view(bsz, length, self.input_size) + # return layer_output, router_logits + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + original_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + + router_logits = self.router.layer(hidden_states) + router_weights, selected_experts = router_logits.topk(self.top_k, dim=-1) + + router_weights = F.softmax(router_weights.float(), dim=-1) + router_weights = router_weights.type_as(hidden_states) + + with torch.no_grad(): + sorted_expert_idxs, sorted_scattered_idxs = selected_experts.flatten().sort() + expert_frequency = bincount(x=sorted_expert_idxs, minlength=self.num_experts) + expert_offsets = expert_frequency.cumsum(-1) + + hidden_states = scattered_experts( + inputs=hidden_states, + expert_weights=self.input_linear.weight.permute(0, 2, 1), + k=self.top_k, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=None, + grouped_in=False, + grouped_out=True, + ) - expert_inputs = layer_input[batch_index] - hidden_states = self.input_linear(expert_inputs, expert_size) chunked_hidden_states = hidden_states.chunk(2, dim=-1) hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] - expert_outputs = self.output_linear(hidden_states, expert_size) - expert_outputs = expert_outputs * batch_gates[:, None] + hidden_states = scattered_experts( + inputs=hidden_states, + expert_weights=self.weight.permute(0, 2, 1), + k=1, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=router_weights, + grouped_in=True, + grouped_out=False, + ) + + hidden_states = hidden_states.view(original_shape) - zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) - layer_output = zeros.index_add(0, batch_index, expert_outputs) - layer_output = layer_output.view(bsz, length, self.input_size) - return layer_output, router_logits + return hidden_states, router_logits # Copied from transformers.models.granite.modeling_granite.repeat_kv with Granite->GraniteMoe From 83f2722fc857b586d8cb00e2da6c6029142cb476 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Thu, 21 Aug 2025 16:35:10 -0700 Subject: [PATCH 3/5] refactor --- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- .../{kernels => models/granitemoe}/scattermoe/__init__.py | 6 +++--- .../granitemoe}/scattermoe/group_backward_kernel.py | 0 .../granitemoe}/scattermoe/group_kernel.py | 0 .../granitemoe}/scattermoe/scatter_kernel.py | 0 5 files changed, 4 insertions(+), 4 deletions(-) rename src/transformers/{kernels => models/granitemoe}/scattermoe/__init__.py (95%) rename src/transformers/{kernels => models/granitemoe}/scattermoe/group_backward_kernel.py (100%) rename src/transformers/{kernels => models/granitemoe}/scattermoe/group_kernel.py (100%) rename src/transformers/{kernels => models/granitemoe}/scattermoe/scatter_kernel.py (100%) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 27ec0e206986..33aa82dfeb1e 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -29,7 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ...utils.deprecation import deprecate_kwarg -from ...kernels.scattermoe import scattered_experts +from .scattermoe import scattered_experts from .configuration_granitemoe import GraniteMoeConfig diff --git a/src/transformers/kernels/scattermoe/__init__.py b/src/transformers/models/granitemoe/scattermoe/__init__.py similarity index 95% rename from src/transformers/kernels/scattermoe/__init__.py rename to src/transformers/models/granitemoe/scattermoe/__init__.py index 442f0e2400b4..b7bd2fa38eb2 100644 --- a/src/transformers/kernels/scattermoe/__init__.py +++ b/src/transformers/models/granitemoe/scattermoe/__init__.py @@ -4,9 +4,9 @@ import torch -from transformers.kernels.scattermoe.group_backward_kernel import group_bwd_W -from transformers.kernels.scattermoe.group_kernel import group -from transformers.kernels.scattermoe.scatter_kernel import scatter2scatter +from .group_backward_kernel import group_bwd_W +from .group_kernel import group +from .scatter_kernel import scatter2scatter class _ScatteredExperts(torch.autograd.Function): diff --git a/src/transformers/kernels/scattermoe/group_backward_kernel.py b/src/transformers/models/granitemoe/scattermoe/group_backward_kernel.py similarity index 100% rename from src/transformers/kernels/scattermoe/group_backward_kernel.py rename to src/transformers/models/granitemoe/scattermoe/group_backward_kernel.py diff --git a/src/transformers/kernels/scattermoe/group_kernel.py b/src/transformers/models/granitemoe/scattermoe/group_kernel.py similarity index 100% rename from src/transformers/kernels/scattermoe/group_kernel.py rename to src/transformers/models/granitemoe/scattermoe/group_kernel.py diff --git a/src/transformers/kernels/scattermoe/scatter_kernel.py b/src/transformers/models/granitemoe/scattermoe/scatter_kernel.py similarity index 100% rename from src/transformers/kernels/scattermoe/scatter_kernel.py rename to src/transformers/models/granitemoe/scattermoe/scatter_kernel.py From e558968e87ae4925f229eb1593bc18d1e28d3ad9 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Fri, 22 Aug 2025 13:41:51 -0700 Subject: [PATCH 4/5] refactor --- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 33aa82dfeb1e..b4aab07ffe12 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -389,7 +389,7 @@ def __init__(self, config: GraniteMoeConfig): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: original_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = hidden_states.view(-1, self.input_size) router_logits = self.router.layer(hidden_states) router_weights, selected_experts = router_logits.topk(self.top_k, dim=-1) From d8170a7a71306594b44f586d51de4be72338f574 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Fri, 22 Aug 2025 13:42:45 -0700 Subject: [PATCH 5/5] refactor --- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index b4aab07ffe12..ddfffc977ca5 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -419,7 +419,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = scattered_experts( inputs=hidden_states, - expert_weights=self.weight.permute(0, 2, 1), + expert_weights=self.output_linear.weight.permute(0, 2, 1), k=1, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs,