add scattermoe kernel for fast MoE training #40365
Conversation
ArthurZucker left a comment
Hello @mayank31398! Nice PR, happy to add something like that. Do you mind using kernels like what we do for GPT-OSS? This way we keep a slow path compatible with all torch versions, all hardware, etc., we don't need code changes in the core modeling, and we just have the kernel on the Hub!
WDYT? 🤗
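(For context, the approach referenced here keeps the plain PyTorch forward as the portable slow path and only swaps in a Hub kernel when one is mapped for the current device. Below is a minimal sketch of that pattern, assuming the `kernels` library API that is used later in this thread; the module body is a stand-in and the repo id / layer name are illustrative, not an existing kernel at this point in the conversation.)

```python
import torch
from torch import nn
from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


# The decorator only tags the class; the pure-PyTorch forward below remains the
# slow path and is what runs when no kernel is mapped for the device.
@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
class ToyMoEBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)  # stand-in for the real MoE math

    def forward(self, hidden_states):
        return self.proj(hidden_states)


model = ToyMoEBlock(hidden_size=64)

# Map the tagged name to a kernel repo on the Hub, per device.
mapping = {
    "ScatterMoEGatedMLP": {
        "cuda": LayerRepository(
            repo_id="shawntan/scattermoe",   # illustrative; the repo discussed later in this thread
            layer_name="ScatterMoEGatedMLP",
        )
    }
}

with use_kernel_mapping(mapping):
    model = kernelize(model, mode=Mode.TRAINING)
```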
@ArthurZucker scattermoe doesn't support bias for now, I will add this soon!

[For maintainers] Suggested jobs to run (before merge): run-slow: granitemoe

Is there an existing triton kernel you could point to that I could follow?

This one: https://huggingface.co/kernels-community/megablocks/tree/main/torch-ext/megablocks (might not be triton), and otherwise https://huggingface.co/kernels-community/triton_kernels, which is fully triton!
Sorry it took a while. I've tried to piece together the various guides for community kernels to package https://huggingface.co/shawntan/scattermoe. This is what I have so far, and the following seems to work:

```python
from kernels import (
    LocalLayerRepository,
    use_kernel_mapping,
    Mode,
    use_kernel_forward_from_hub,
    kernelize,
)
from transformers import AutoTokenizer, AutoConfig
from transformers.activations import ACT2FN
from pathlib import Path
import torch
from torch import nn
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridConfig,
    GraniteMoeHybridParallelExperts,
    GraniteMoeHybridTopKGating,
)


@use_kernel_forward_from_hub('ScatterMoEGatedMLP')
class GraniteMoeHybridMoE(nn.Module):
    """
    A sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__()
        self.input_size = config.hidden_size
        self.hidden_size = config.intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        self.input_linear = GraniteMoeHybridParallelExperts(
            config.num_local_experts, self.input_size, self.hidden_size * 2
        )
        self.output_linear = GraniteMoeHybridParallelExperts(
            config.num_local_experts, self.hidden_size, self.input_size
        )
        self.router = GraniteMoeHybridTopKGating(
            input_size=self.input_size,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
        )

    def forward(self, layer_input):
        """
        Forward pass of the mixture of experts layer.

        Args:
            layer_input (Tensor):
                Input tensor.

        Returns:
            Tensor:
                Output tensor.
            Tensor:
                Router logits.
        """
        bsz, length, emb_size = layer_input.size()
        layer_input = layer_input.reshape(-1, emb_size)
        _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)

        expert_inputs = layer_input[batch_index]
        hidden_states = self.input_linear(expert_inputs, expert_size)
        chunked_hidden_states = hidden_states.chunk(2, dim=-1)
        hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
        expert_outputs = self.output_linear(hidden_states, expert_size)

        expert_outputs = expert_outputs * batch_gates[:, None]

        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
        layer_output = zeros.index_add(0, batch_index, expert_outputs)
        layer_output = layer_output.view(bsz, length, self.input_size)
        return layer_output, router_logits


model_path = "ibm-granite/granite-4.0-h-tiny-base"
device = torch.device("cuda")

kernel_layer_mapping = {
    "ScatterMoEGatedMLP": {
        "cuda": LocalLayerRepository(
            repo_path=Path("/u/shawntan/hf_scattermoe"),
            package_name='scattermoe',
            layer_name="ScatterMoEGatedMLP"
        )
        # "cuda": LayerRepository(
        #     repo_id='shawntan/scattermoe',
        #     layer_name='ScatterMoEGatedMLP'
        # )
    }
}

# scattermoe = get_kernel("shawntan/scattermoe")
config = AutoConfig.from_pretrained(model_path, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = GraniteMoeHybridMoE(config).to(device)
for p in model.parameters():
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(4, 4096, 1536, device=device)
out_reference, _ = model(x)

with use_kernel_mapping(kernel_layer_mapping):
    model = kernelize(model, mode=Mode.TRAINING)

out_kernel, _ = model(x)
print((out_reference - out_kernel).abs().max())
```

What further steps need to be done for submitting it to community kernels? And also to include it in the
@ArthurZucker any thoughts?

Hey @shawntan thanks for this contribution, the kernel looks good to me! We can have it in

Alright! What do you need specifically? I can come up with some benchmarks for the Granite models.
MekkCyber left a comment
Yes, that would be great! We need some latency & memory benchmarks with and without the kernel for different sequence lengths, to see whether we get speedups or improved memory consumption.
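(For reference, a rough sketch of one way to gather such numbers. It is not the benchmark actually used in this thread; it reuses `model` and `kernel_layer_mapping` from the snippet above, times the forward pass only, and the sequence lengths are arbitrary.)

```python
import torch
from kernels import Mode, kernelize, use_kernel_mapping


def bench(module, seq_len, hidden_size=1536, batch=4, iters=20):
    """Median forward latency (ms) and peak memory (GiB) at a given sequence length."""
    x = torch.randn(batch, seq_len, hidden_size, device="cuda")
    torch.cuda.reset_peak_memory_stats()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        module(x)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sorted(times)[len(times) // 2], torch.cuda.max_memory_allocated() / 2**30


seq_lens = (512, 2048, 8192)

# Reference (pure PyTorch) numbers first, since kernelize patches the module.
reference = {s: bench(model, s) for s in seq_lens}

with use_kernel_mapping(kernel_layer_mapping):
    model = kernelize(model, mode=Mode.TRAINING)
kernel = {s: bench(model, s) for s in seq_lens}

for s in seq_lens:
    (r_ms, r_gb), (k_ms, k_gb) = reference[s], kernel[s]
    print(f"seq_len={s}: {r_ms:.1f} ms / {r_gb:.2f} GiB  vs  {k_ms:.1f} ms / {k_gb:.2f} GiB")
```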
The ScatterMoE report (https://arxiv.org/pdf/2403.08245) has some benchmarks on previous models compared against MegaBlocks, covering both speedup and memory usage.
Very nice performance @shawntan! Thanks for sharing! So the kernel is used for training, not inference?

It can be used for both, but the gains will mainly come in training, or in prefill-heavy cases. I've made the draft changes for
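(As an aside, switching between the two uses comes down to the `Mode` passed to `kernelize`. A sketch reusing the mapping from the earlier snippet, and assuming the `kernels` library also exposes `Mode.INFERENCE` alongside the `Mode.TRAINING` used there:)

```python
from kernels import Mode, kernelize, use_kernel_mapping

with use_kernel_mapping(kernel_layer_mapping):
    # Mode.TRAINING keeps a differentiable kernel path for fine-tuning;
    # Mode.INFERENCE (assumed to exist) selects the inference variant instead.
    model = kernelize(model, mode=Mode.INFERENCE)
```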
Started another PR here, #41458, because it is significantly different from the current one.
MekkCyber left a comment
Let's close this PR then; it will be superseded by #41458.
ciao