From aa2679c77da251f917b1cd8d3abdad9a2f74ec3b Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 10 Oct 2025 04:36:25 +0000 Subject: [PATCH 1/3] Add ScatterMoE. --- scattermoe/README.md | 3 + scattermoe/build.toml | 3 + scattermoe/flake.nix | 17 + scattermoe/torch-ext/scattermoe/__init__.py | 15 + .../torch-ext/scattermoe/kernels/__init__.py | 3 + .../torch-ext/scattermoe/kernels/ops.py | 457 ++++++++++++++++++ .../torch-ext/scattermoe/kernels/single.py | 59 +++ scattermoe/torch-ext/scattermoe/layers.py | 52 ++ scattermoe/torch-ext/scattermoe/mlp.py | 96 ++++ .../torch-ext/scattermoe/parallel_experts.py | 182 +++++++ 10 files changed, 887 insertions(+) create mode 100644 scattermoe/README.md create mode 100644 scattermoe/build.toml create mode 100644 scattermoe/flake.nix create mode 100644 scattermoe/torch-ext/scattermoe/__init__.py create mode 100644 scattermoe/torch-ext/scattermoe/kernels/__init__.py create mode 100644 scattermoe/torch-ext/scattermoe/kernels/ops.py create mode 100644 scattermoe/torch-ext/scattermoe/kernels/single.py create mode 100644 scattermoe/torch-ext/scattermoe/layers.py create mode 100644 scattermoe/torch-ext/scattermoe/mlp.py create mode 100644 scattermoe/torch-ext/scattermoe/parallel_experts.py diff --git a/scattermoe/README.md b/scattermoe/README.md new file mode 100644 index 0000000..69ac952 --- /dev/null +++ b/scattermoe/README.md @@ -0,0 +1,3 @@ +This is the [ScatterMoE](https://arxiv.org/abs/2403.08245) kernel, written in Triton. + +The main repository for ScatterMoE is [here](https://github.com/shawntan/scattermoe). \ No newline at end of file diff --git a/scattermoe/build.toml b/scattermoe/build.toml new file mode 100644 index 0000000..2080800 --- /dev/null +++ b/scattermoe/build.toml @@ -0,0 +1,3 @@ +[general] +name = "scattermoe" +universal = true diff --git a/scattermoe/flake.nix b/scattermoe/flake.nix new file mode 100644 index 0000000..d9491f6 --- /dev/null +++ b/scattermoe/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for scattermoe kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/scattermoe/torch-ext/scattermoe/__init__.py b/scattermoe/torch-ext/scattermoe/__init__.py new file mode 100644 index 0000000..2af2dd0 --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/__init__.py @@ -0,0 +1,15 @@ +from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts +from . import parallel_experts +from . import kernels +from . import mlp +from . import layers + +__all__ = [ + "flatten_sort_count", + "parallel_linear", + "ParallelExperts", + "parallel_experts", + "kernels", + "mlp", + "layers" +] diff --git a/scattermoe/torch-ext/scattermoe/kernels/__init__.py b/scattermoe/torch-ext/scattermoe/kernels/__init__.py new file mode 100644 index 0000000..8d3cc9b --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/kernels/__init__.py @@ -0,0 +1,3 @@ +from . 
import ops + +__all__ = ["ops"] \ No newline at end of file diff --git a/scattermoe/torch-ext/scattermoe/kernels/ops.py b/scattermoe/torch-ext/scattermoe/kernels/ops.py new file mode 100644 index 0000000..b47534c --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/kernels/ops.py @@ -0,0 +1,457 @@ +import torch +import triton +import triton.language as tl +from typing import Optional + +BLOCK_M = 128 +ALLOW_TF32 = True + + + +@triton.jit +def _compute_expert_block( + E_idx, E_mask, + M_in_idx, + N_block, N_mask, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + K, + acc, + no_k_mask, + BLOCK_K, + allow_tf32=True, +): + + K_block = tl.arange(0, BLOCK_K) + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we + iters = tl.cdiv(K, BLOCK_K) + + for K_block_id in range(iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc = tl.dot(x, w, acc, allow_tf32=allow_tf32) + return acc + + +def _scatter2scatter_configs(): + return [ + triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4), + ] + +@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], ) +@triton.heuristics({ + "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, + "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, +}) +@triton.jit +def _scatter2scatter( + X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr, + W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr, + Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr, + B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr, + grouped_idx_ptr, expert_idxs_ptr, + # block_start_idx_ptr, + FAN_OUT: tl.constexpr, + M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + # OUT_M, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, y_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + + M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + M_boundary_mask = M_block < (FAN_OUT * M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E) + + no_k_mask = K % BLOCK_K == 0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + E_first_idx = tl.min(E_idxs) + E_last_idx = tl.minimum(tl.max(E_idxs), E - 1) + M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32) + for E_idx in range(E_first_idx, E_last_idx + 1): + E_mask = E_idxs == E_idx + E_M_idx = M_idx + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = E_M_idx // FAN_OUT + acc = _compute_expert_block( + E_idx, E_mask, + M_in_idx, N_block, N_mask, + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + K, + acc, + no_k_mask, + BLOCK_K, + allow_tf32=allow_tf32, + ) + + if B_ptr is not None: + B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn + acc 
+= tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :]) + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :]) + +def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k, + b=None, + x_grouped=False, y_grouped=False, + out=None): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + # Pre-kernel setup + y_dim = W.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + output = out + + scatter2scatter_compileable(output, W, X, k, sorted_expert_idxs, sorted_scattered_idxs, + b, x_grouped, y_grouped) + return output + + +@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"}) +def scatter2scatter_compileable( + output: torch.Tensor, + W: torch.Tensor, + X: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + b: Optional[torch.Tensor], + x_grouped: bool, y_grouped: bool) -> None: + def grid(META): + grid_num = ( + triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) * + triton.cdiv(META['N'], META['BLOCK_N']), + ) + return grid_num + + if b is None: + b = None + stride_be = stride_bk = 0 + else: + stride_be, stride_bk = b.stride() + + _scatter2scatter[grid]( + # X_ptr, stride_xm, stride_xk, + X, X.stride(0), X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, W.stride(0), W.stride(1), W.stride(2), + # Y_ptr, stride_ym, stride_yn, + output, output.stride(0), output.stride(1), + # B_ptr, stride_be, stride_bk + b, stride_be, stride_bk, + grouped_idx_ptr=sorted_scattered_idxs, + expert_idxs_ptr=sorted_expert_idxs, + # block_start_idx_ptr=padded_block_idxs, + FAN_OUT=k, + M=X.size(0), + K=X.size(1), + N=output.size(1), E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, + x_grouped=x_grouped, y_grouped=y_grouped, + ) + + +def _config_XtY(): + return [ + triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4), + ] + +def group_bwd_W(DY, X, expert_offsets, E, has_bias=False): + DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) + DW = DWt.permute(0, 2, 1) + if has_bias: + Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype) + else: + Db = None + groupXtY_compileable(E, DW, Db, DY, X, expert_offsets) + return DW, Db + + +@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"}) +def groupXtY_compileable( + E: int, + DW: torch.Tensor, + Db: Optional[torch.Tensor], + DY: torch.Tensor, + X: torch.Tensor, + expert_offsets: torch.Tensor) -> None: + def grid(META): + grid = ( + E * triton.cdiv(META['K'], META['BLOCK_K']), + triton.cdiv(META['N'], META['BLOCK_N']), + ) + return grid + + if Db is None: + stride_dbe = 0 + stride_dbn = 0 + else: + stride_dbe, stride_dbn = Db.stride() + + _groupXtY[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, DY.stride(0), DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, X.stride(0), X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, DW.stride(0), DW.stride(1), DW.stride(2), + # Db_ptr, stride_dwe, stride_dbn, + Db, stride_dbe, stride_dbn, + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, 
+ M=DY.size(0), N=DY.size(-1), K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32 + ) + + +@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], ) +@triton.heuristics({ + "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0, + "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0, +}) +@triton.jit +def _groupXtY( + DY_ptr, stride_dym, stride_dyk, + X_ptr, stride_xm, stride_xn, + DW_ptr, stride_dwe, stride_dwk, stride_dwn, + Db_ptr, stride_dbe, stride_dbn, + expert_offsets_ptr, + M, K: tl.constexpr, N: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + if (Db_ptr is not None) and (K_block_id == 0): + _xty_and_bias( + E_idx, start_idx, end_idx, + M_block, + K_block, K_mask, N_block, N_mask, + dy_blk_ptrs, stride_dym, + xt_blk_ptrs, stride_xm, + DW_ptr, stride_dwe, stride_dwk, stride_dwn, + Db_ptr, stride_dbe, stride_dbn, + BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, + allow_tf32, NO_K_MASK, NO_N_MASK, + compute_bias=True + ) + else: + _xty_and_bias( + E_idx, start_idx, end_idx, + M_block, + K_block, K_mask, N_block, N_mask, + dy_blk_ptrs, stride_dym, + xt_blk_ptrs, stride_xm, + DW_ptr, stride_dwe, stride_dwk, stride_dwn, + Db_ptr, stride_dbe, stride_dbn, + BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, + allow_tf32, NO_K_MASK, NO_N_MASK, + compute_bias=False + ) + + +@triton.jit +def _xty_and_bias( + E_idx, start_idx, end_idx, + M_block, + K_block, K_mask, N_block, N_mask, + dy_blk_ptrs, stride_dym, + xt_blk_ptrs, stride_xm, + DW_ptr, stride_dwe, stride_dwk, stride_dwn, + Db_ptr, stride_dbe, stride_dbn, + BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, + allow_tf32, NO_K_MASK, NO_N_MASK, + compute_bias: tl.constexpr + ): + + if compute_bias: + db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE) + else: + db_acc = None + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + if NO_K_MASK: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + if NO_N_MASK: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + acc += tl.dot(xt, dy, 
out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + + if compute_bias: + db_acc += tl.sum(dy, axis=0) + + DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + if compute_bias: + Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn + tl.store(Db_blk_ptrs, db_acc, mask=N_mask) + + + +def _config_grouping(): + return [ + triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4), + ] + +def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + if out is not None: + Y = out + else: + Y = torch.empty((N, K), dtype=A.dtype, device=A.device) + group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs) + return Y + + +@torch.library.custom_op("scattermoe::group", mutates_args={"Y"}) +def group_compileable( + A: torch.Tensor, + K: int, + N: int, + Y: torch.Tensor, + coeff: torch.Tensor, has_coeff: bool, + fan_out: int, + sorted_expert_idxs: torch.Tensor) -> None: + def grid(META): + grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),) + return grid_num + _group[grid]( + # A_ptr, stride_an, stride_ai, + A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out, + # Y_ptr, stride_yn, stride_yk, + Y, Y.stride(0), Y.stride(1), + # grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, K + ) + + +@triton.autotune(configs=_config_grouping(), key=['K']) +@triton.heuristics({ + "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0 +}) +@triton.jit +def _group( + src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr, + tgt_ptr, stride_tn, stride_ti, + grouped_idx_ptr, + N, K: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NO_K_MASK: tl.constexpr +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + for i in range(0, iters): + if NO_K_MASK or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + if has_coeff: + block *= c + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + if has_coeff: + block *= c + tl.store(tgt_blk_ptrs, block, mask=mask) + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/scattermoe/torch-ext/scattermoe/kernels/single.py b/scattermoe/torch-ext/scattermoe/kernels/single.py new file mode 100644 index 0000000..1438efb --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/kernels/single.py @@ -0,0 +1,59 @@ +import torch +import triton +import triton.language as tl + 
+@triton.jit +def _single2scatter( + X_ptr, stride_xm, stride_xk, + W_ptr, stride_we, stride_wk, stride_wn, + Y_ptr, stride_ym, stride_yn, + expert_idxs_ptr, + FAN_OUT: tl.constexpr, + K: tl.constexpr, N: tl.constexpr, E: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + N_block_id = pid0 + if FAN_OUT == 1: + in_idx = pid1 + else: + in_idx = 0 + out_idx = pid1 + + K_block = tl.arange(0, BLOCK_K) + N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N) + E_idx = tl.load(expert_idxs_ptr + pid1) + X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk + W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE) + for K_block_id in range(0, tl.cdiv(K, BLOCK_K)): + x = tl.load(X_blk_ptrs) + w = tl.load(W_blk_ptrs) + acc += tl.sum(x * w, axis=0)[None, :] + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn + tl.store(Y_blk_ptrs, acc) + +def single2scatter(X, W, expert_idxs): + E, xdim, ydim = W.size() + k = expert_idxs.size(1) + assert X.size(0) == k or X.size(0) == 1 + Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype) + BLOCK_N = 128 + BLOCK_K = 128 + grid = ydim // BLOCK_N, k + _single2scatter[grid]( + X, X.stride(0), X.stride(1), + W, W.stride(0), W.stride(1), W.stride(2), + Y, Y.stride(0), Y.stride(1), + expert_idxs, + FAN_OUT=Y.size(0) // X.size(0), + K=xdim, N=ydim, E=E, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + ACC_TYPE=tl.float32 + ) + return Y diff --git a/scattermoe/torch-ext/scattermoe/layers.py b/scattermoe/torch-ext/scattermoe/layers.py new file mode 100644 index 0000000..cabe7d5 --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/layers.py @@ -0,0 +1,52 @@ +import torch +from torch.nn import functional as F +from torch import nn + +from . import parallel_linear, flatten_sort_count + +class ScatterMoEGatedMLP(nn.Module): + def forward(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. + Tensor: + Router logits. 
+ """ + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + # compute the top_k routing decision + router_logits = self.router.layer(layer_input) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.router.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(layer_input.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \ + flatten_sort_count(selected_experts, num_experts=self.router.num_experts) + + # compute experts + gates, h = parallel_linear( + layer_input, self.input_linear.weight.transpose(2, 1), + self.router.top_k, + sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + grouped_in=False, grouped_out=True, + ).chunk(2, dim=-1) + h = self.activation(gates) * h + layer_output = parallel_linear( + h, self.output_linear.weight.transpose(2, 1), + 1, + sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + grouped_in=True, grouped_out=False, + gates=routing_weights + ) + layer_output = layer_output.view(bsz, length, emb_size) + return layer_output + diff --git a/scattermoe/torch-ext/scattermoe/mlp.py b/scattermoe/torch-ext/scattermoe/mlp.py new file mode 100644 index 0000000..67796ec --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/mlp.py @@ -0,0 +1,96 @@ +import torch +from torch import nn + +from .parallel_experts import ParallelExperts, flatten_sort_count + +class MLP(nn.Module): + def __init__( + self, + input_size, + hidden_size, + num_experts, + top_k, + bias=False, + activation=None, + ): + super(MLP, self).__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.hidden_size = hidden_size + self.experts = ParallelExperts(num_experts, input_size, hidden_size, bias=bias) + self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias) + self.top_k = min(top_k, self.num_experts) + self.activation = activation + + def extra_repr(self): + return 'k={}'.format(self.top_k) + + def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor): + x_shape = x.size() + x = x.view(-1, x_shape[-1]) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \ + flatten_sort_count(expert_idxs, num_experts=self.num_experts) + + h = self.experts( + x, self.top_k, + sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + grouped_out=True + ) + h = self.activation(h) + y = self.output_experts( + h, 1, sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + gates=expert_p, + ) + y = y.view(*x_shape[:-1], y.size(-1)) + return y + +class GLUMLP(nn.Module): + def __init__( + self, + input_size, + hidden_size, + num_experts, + top_k, + bias=False, + activation=nn.SiLU(), + ): + super(GLUMLP, self).__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.hidden_size = hidden_size + self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size, bias=bias) + self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias) + self.top_k = min(top_k, self.num_experts) + self.activation = activation + + def extra_repr(self): + return 'k={}'.format(self.top_k) + + def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor): + x_shape = x.size() + x = x.view(-1, x_shape[-1]) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \ + flatten_sort_count(expert_idxs, 
num_experts=self.num_experts) + + + h, gates = self.experts( + x, self.top_k, + sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + grouped_out=True + ).chunk(2, dim=-1) + h = self.activation(gates) * h + y = self.output_experts( + h, 1, sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + gates=expert_p, + ) + y = y.view(*x_shape[:-1], y.size(-1)) + return y + diff --git a/scattermoe/torch-ext/scattermoe/parallel_experts.py b/scattermoe/torch-ext/scattermoe/parallel_experts.py new file mode 100644 index 0000000..fe67f25 --- /dev/null +++ b/scattermoe/torch-ext/scattermoe/parallel_experts.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +from . import kernels +from typing import Optional + +@torch.library.custom_op("scattermoe::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + +@torch.compile +def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int): + with torch.no_grad(): + flattened_expert_idxs = expert_idxs.flatten() + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) + expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts) + expert_offsets = expert_counts.cumsum(-1) + return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets + + + +class ParallelLinear(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, expert_weights: torch.Tensor, k: int, + sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + expert_biases: Optional[torch.Tensor]=None, + gates: Optional[torch.Tensor]=None, + grouped_in: bool =False, grouped_out: bool=False, + ): + with torch.device(x.device): + output = kernels.ops.scatter2scatter( + X=x, W=expert_weights, + b=expert_biases, k=k, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + x_grouped=grouped_in, y_grouped=grouped_out + ) + if gates is not None: + output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1)) + output = (gates.unsqueeze(1) @ output_expanded).squeeze(1) + else: + output_expanded = None + + ctx.save_for_backward( + x, expert_weights, + expert_biases, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded + ) + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + return output + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + with torch.device(grad_out.device): + (x, expert_weights, expert_biases, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, output_expanded) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + # print("backward") + + if gates is not None: + # calculate gates gradient + # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) + d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later + else: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + + if grouped_out: + grouped_grad_out = grad_out + else: + grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs, + fan_out=gate_fan, 
coeff=gates_flat, + out=grouped_grad_out) + if grouped_in: + grouped_x = x + d_expanded_input = None + else: + grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k) + d_expanded_input = grouped_x + + d_weights, d_biases = kernels.ops.group_bwd_W( + DY=grouped_grad_out, X=grouped_x, + expert_offsets=expert_offsets, + E=expert_weights.size(0), + has_bias=expert_biases is not None + ) + + + d_expanded_input = kernels.ops.scatter2scatter( + X=grouped_grad_out, x_grouped=True, + W=expert_weights.permute(0, 2, 1), + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + y_grouped=grouped_in, + out=d_expanded_input # Reuse grouped_x buffer + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2) + # print("backward end.") + return ( + # x, expert_weights, + d_input, d_weights, + # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets, + None, None, None, None, + # bias, gates + d_biases, d_gates, + # grouped_in, grouped_out, + None, None + ) + +def parallel_linear(inputs, expert_weights, k, + sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + expert_biases=None, + gates=None, grouped_in=False, grouped_out=False): + results = ParallelLinear.apply(inputs, expert_weights, k, + sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + expert_biases, + gates, grouped_in, grouped_out) + return results + +class ParallelExperts(nn.Module): + def __init__(self, num_experts, input_size, output_size, bias=False) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + + if bias: + self.bias = nn.Parameter(torch.empty(num_experts, output_size)) + else: + self.bias = None + + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + self.reset_parameters() + + def extra_repr(self): + return 'num_experts={}, input_size={}, output_size={}'.format( + self.num_experts, self.input_size, self.output_size) + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.02) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs, + expert_offsets, + gates=None, grouped_in=False, grouped_out=False): + + results = parallel_linear( + inputs, self.weight.permute(0, 2, 1), k, + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets, + expert_biases=self.bias, + gates=gates, grouped_in=grouped_in, grouped_out=grouped_out + ) + return results From 9dd104ee5e29833c61b8b50b8452e16bcd7b8308 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 10 Oct 2025 20:42:44 +0000 Subject: [PATCH 2/3] Remove MLP class. 
--- scattermoe/README.md | 2 + scattermoe/torch-ext/scattermoe/__init__.py | 2 - scattermoe/torch-ext/scattermoe/layers.py | 2 +- scattermoe/torch-ext/scattermoe/mlp.py | 96 --------------------- 4 files changed, 3 insertions(+), 99 deletions(-) delete mode 100644 scattermoe/torch-ext/scattermoe/mlp.py diff --git a/scattermoe/README.md b/scattermoe/README.md index 69ac952..d70352c 100644 --- a/scattermoe/README.md +++ b/scattermoe/README.md @@ -1,3 +1,5 @@ +# ScatterMoE + This is the [ScatterMoE](https://arxiv.org/abs/2403.08245) kernel, written in Triton. The main repository for ScatterMoE is [here](https://github.com/shawntan/scattermoe). \ No newline at end of file diff --git a/scattermoe/torch-ext/scattermoe/__init__.py b/scattermoe/torch-ext/scattermoe/__init__.py index 2af2dd0..e52e82f 100644 --- a/scattermoe/torch-ext/scattermoe/__init__.py +++ b/scattermoe/torch-ext/scattermoe/__init__.py @@ -1,7 +1,6 @@ from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts from . import parallel_experts from . import kernels -from . import mlp from . import layers __all__ = [ @@ -10,6 +9,5 @@ "ParallelExperts", "parallel_experts", "kernels", - "mlp", "layers" ] diff --git a/scattermoe/torch-ext/scattermoe/layers.py b/scattermoe/torch-ext/scattermoe/layers.py index cabe7d5..b2aa535 100644 --- a/scattermoe/torch-ext/scattermoe/layers.py +++ b/scattermoe/torch-ext/scattermoe/layers.py @@ -48,5 +48,5 @@ def forward(self, layer_input): gates=routing_weights ) layer_output = layer_output.view(bsz, length, emb_size) - return layer_output + return layer_output, router_logits diff --git a/scattermoe/torch-ext/scattermoe/mlp.py b/scattermoe/torch-ext/scattermoe/mlp.py deleted file mode 100644 index 67796ec..0000000 --- a/scattermoe/torch-ext/scattermoe/mlp.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -from torch import nn - -from .parallel_experts import ParallelExperts, flatten_sort_count - -class MLP(nn.Module): - def __init__( - self, - input_size, - hidden_size, - num_experts, - top_k, - bias=False, - activation=None, - ): - super(MLP, self).__init__() - - self.num_experts = num_experts - self.input_size = input_size - self.hidden_size = hidden_size - self.experts = ParallelExperts(num_experts, input_size, hidden_size, bias=bias) - self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias) - self.top_k = min(top_k, self.num_experts) - self.activation = activation - - def extra_repr(self): - return 'k={}'.format(self.top_k) - - def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor): - x_shape = x.size() - x = x.view(-1, x_shape[-1]) - sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \ - flatten_sort_count(expert_idxs, num_experts=self.num_experts) - - h = self.experts( - x, self.top_k, - sorted_expert_idxs, sorted_scattered_idxs, - expert_offsets, - grouped_out=True - ) - h = self.activation(h) - y = self.output_experts( - h, 1, sorted_expert_idxs, sorted_scattered_idxs, - expert_offsets, - grouped_in=True, - gates=expert_p, - ) - y = y.view(*x_shape[:-1], y.size(-1)) - return y - -class GLUMLP(nn.Module): - def __init__( - self, - input_size, - hidden_size, - num_experts, - top_k, - bias=False, - activation=nn.SiLU(), - ): - super(GLUMLP, self).__init__() - - self.num_experts = num_experts - self.input_size = input_size - self.hidden_size = hidden_size - self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size, bias=bias) - self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, 
bias=bias) - self.top_k = min(top_k, self.num_experts) - self.activation = activation - - def extra_repr(self): - return 'k={}'.format(self.top_k) - - def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor): - x_shape = x.size() - x = x.view(-1, x_shape[-1]) - sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \ - flatten_sort_count(expert_idxs, num_experts=self.num_experts) - - - h, gates = self.experts( - x, self.top_k, - sorted_expert_idxs, sorted_scattered_idxs, - expert_offsets, - grouped_out=True - ).chunk(2, dim=-1) - h = self.activation(gates) * h - y = self.output_experts( - h, 1, sorted_expert_idxs, sorted_scattered_idxs, - expert_offsets, - grouped_in=True, - gates=expert_p, - ) - y = y.view(*x_shape[:-1], y.size(-1)) - return y - From 6c4f7750b62d2134715424e44a3f41952c7a65bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Sat, 11 Oct 2025 11:47:56 +0200 Subject: [PATCH 3/3] Update flake --- scattermoe/flake.lock | 168 ++++++++++++++++++++++++++++++++++++++++++ scattermoe/flake.nix | 2 +- 2 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 scattermoe/flake.lock diff --git a/scattermoe/flake.lock b/scattermoe/flake.lock new file mode 100644 index 0000000..65c5dc2 --- /dev/null +++ b/scattermoe/flake.lock @@ -0,0 +1,168 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1759851564, + "narHash": "sha256-Xybkhm0FM/VzlZ5WndTYq/X/9MAeddd4EQ2Vz8GdkOA=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "351655d9f124805ed7c1193aa61550ce245f4570", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760097360, + "narHash": 
"sha256-nwQoIwtT5zuV3QF9loxKVjGdS2Y55L4Lkfo3N/2eoTQ=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "0cf2febab3e9ff2eca9b43115776b52c2f253f7f", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1755963616, + "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/scattermoe/flake.nix b/scattermoe/flake.nix index d9491f6..33185d1 100644 --- a/scattermoe/flake.nix +++ b/scattermoe/flake.nix @@ -11,7 +11,7 @@ kernel-builder, }: kernel-builder.lib.genFlakeOutputs { + inherit self; path = ./.; - rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; }; }