
Conversation

@mayank31398
Contributor

No description provided.

@Rocketknight1
Member

cc @ArthurZucker

@ArthurZucker
Collaborator

Hello @mayank31398! Nice PR, happy to add something like this. Do you mind using kernels, like what we do for GPT_OSS?
This way we keep a slow path that is compatible with all torch versions and all hardware, avoid code changes in the core modeling, and just have the kernel on the Hub!

WDYT? 🤗
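
For readers unfamiliar with that pattern: it looks roughly like the sketch below. The layer name and repo id here are illustrative placeholders, not the actual GPT_OSS integration.

from torch import nn
from kernels import LayerRepository, Mode, kernelize, use_kernel_forward_from_hub, use_kernel_mapping

# The decorated module keeps its plain PyTorch forward as the slow path,
# which runs on any torch version and any hardware.
@use_kernel_forward_from_hub("MyMoEBlock")  # placeholder layer name
class MyMoEBlock(nn.Module):
    def forward(self, hidden_states):
        return hidden_states  # slow path: ordinary PyTorch ops

# A mapping ties that layer name to a kernel repo on the Hub; kernelize()
# swaps the forward only where such a mapping is registered.
mapping = {
    "MyMoEBlock": {
        "cuda": LayerRepository(repo_id="some-org/some-kernel", layer_name="MyMoEBlock"),  # placeholder repo
    }
}
model = MyMoEBlock()
with use_kernel_mapping(mapping):
    model = kernelize(model, mode=Mode.TRAINING)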

@mayank31398
Contributor Author

@ArthurZucker scattermoe doesn't support bias for now; I will add this soon!
Meanwhile, supporting every model is hard, since some models store expert weights as an nn.ModuleList instead of a 3D tensor :/
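
To make the two layouts concrete, a small illustrative sketch (shapes and names are made up, not taken from any particular model):

import torch
from torch import nn

num_experts, hidden_in, hidden_out = 8, 64, 128

# Layout A: one nn.Linear per expert. Readable, but the weights live in
# separate modules, so a fused/grouped expert kernel cannot consume them directly.
experts_as_modulelist = nn.ModuleList(
    nn.Linear(hidden_in, hidden_out, bias=False) for _ in range(num_experts)
)

# Layout B: a single 3D parameter of shape [num_experts, hidden_out, hidden_in],
# which is the form a grouped matmul such as scattermoe operates on.
experts_as_3d_tensor = nn.Parameter(torch.empty(num_experts, hidden_out, hidden_in))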

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: granitemoe

@shawntan
Contributor

Hello @mayank31398! Nice PR, happy to add something like this. Do you mind using kernels, like what we do for GPT_OSS? This way we keep a slow path that is compatible with all torch versions and all hardware, avoid code changes in the core modeling, and just have the kernel on the Hub!

WDYT? 🤗

Is there an existing Triton kernel you could point to that I could follow?

@ArthurZucker
Collaborator

This one: https://huggingface.co/kernels-community/megablocks/tree/main/torch-ext/megablocks (might not be Triton), and otherwise https://huggingface.co/kernels-community/triton_kernels, which is fully Triton!
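
If it helps to poke at those, the kernels library can pull a hub kernel down directly; a minimal sketch using the repo ids from the links above:

from kernels import get_kernel

# Fetches the pre-built kernel package from the Hub and returns it as a Python
# module whose layers and ops can be inspected or called directly.
megablocks = get_kernel("kernels-community/megablocks")
triton_kernels = get_kernel("kernels-community/triton_kernels")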

@3outeille assigned and then unassigned themselves on Aug 25, 2025
@shawntan
Contributor

shawntan commented Oct 5, 2025

Sorry it took a while. I've tried to piece together the various guides for community kernels to package scattermoe.

https://huggingface.co/shawntan/scattermoe

This is what I have so far, and the following seems to work:

from kernels import (
    LocalLayerRepository,
    use_kernel_mapping, 
    Mode,
    use_kernel_forward_from_hub,
    kernelize
)

from transformers import AutoTokenizer, AutoConfig
from transformers.activations import ACT2FN

from pathlib import Path
import torch
from torch import nn
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridConfig,
    GraniteMoeHybridParallelExperts,
    GraniteMoeHybridTopKGating,
)


@use_kernel_forward_from_hub('ScatterMoEGatedMLP')
class GraniteMoeHybridMoE(nn.Module):
    """
    A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        self.input_linear = GraniteMoeHybridParallelExperts(
            config.num_local_experts, self.input_size, self.hidden_size * 2
        )
        self.output_linear = GraniteMoeHybridParallelExperts(
            config.num_local_experts, self.hidden_size, self.input_size
        )

        self.router = GraniteMoeHybridTopKGating(
            input_size=self.input_size,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
        )

    def forward(self, layer_input):
        """
        Forward pass of the mixture of experts layer.

        Args:
            layer_input (Tensor):
                Input tensor.

        Returns:
            Tensor:
                Output tensor.
            Tensor:
                Router logits.
        """
        bsz, length, emb_size = layer_input.size()
        layer_input = layer_input.reshape(-1, emb_size)
        _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)

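        # Gather each token once per expert assignment, run the grouped expert
        # MLP with a GLU-style gated activation, then project back to the input size.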
        expert_inputs = layer_input[batch_index]
        hidden_states = self.input_linear(expert_inputs, expert_size)
        chunked_hidden_states = hidden_states.chunk(2, dim=-1)
        hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
        expert_outputs = self.output_linear(hidden_states, expert_size)

        expert_outputs = expert_outputs * batch_gates[:, None]

        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
        layer_output = zeros.index_add(0, batch_index, expert_outputs)
        layer_output = layer_output.view(bsz, length, self.input_size)
        return layer_output, router_logits


model_path = "ibm-granite/granite-4.0-h-tiny-base"
device = torch.device("cuda")

kernel_layer_mapping = {
    "ScatterMoEGatedMLP": {
        "cuda": LocalLayerRepository(
            repo_path=Path("/u/shawntan/hf_scattermoe"),
            package_name='scattermoe',
            layer_name="ScatterMoEGatedMLP"
        )
        # "cuda": LayerRepository(
        #     repo_id='shawntan/scattermoe',
        #     layer_name='ScatterMoEGatedMLP'
        # )
    }
}

# scattermoe = get_kernel("shawntan/scattermoe")
config = AutoConfig.from_pretrained(model_path, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GraniteMoeHybridMoE(config).to(device)
for p in model.parameters():
    torch.nn.init.normal_(p, std=0.02)
x = torch.randn(4, 4096, 1536, device=device)
out_reference, _ = model(x)
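# Swap the decorated forward for the ScatterMoE kernel and check it matches the eager reference.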
with use_kernel_mapping(kernel_layer_mapping):
    model = kernelize(model, mode=Mode.TRAINING)
    out_kernel, _ = model(x)

print((out_reference - out_kernel).abs().max())

What further steps are needed to submit it to community kernels, and to include it in the use_kernel_forward_from_hub mapping?
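
Presumably, once the kernel lives on the Hub, the LocalLayerRepository above gets replaced by the hub-based mapping sketched in the commented-out lines, something like the following (the repo id is my staging repo, not a final kernels-community location):

from kernels import LayerRepository

kernel_layer_mapping = {
    "ScatterMoEGatedMLP": {
        "cuda": LayerRepository(
            repo_id="shawntan/scattermoe",  # staging repo; final home not decided yet
            layer_name="ScatterMoEGatedMLP",
        )
    }
}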

@shawntan
Contributor

shawntan commented Oct 6, 2025

@ArthurZucker any thoughts?

@3outeille
Member

@MekkCyber

@MekkCyber
Contributor

Hey @shawntan, thanks for this contribution! The kernel looks good to me. We can have it in kernels-community; then we need some tests and benchmarks to see whether to include it in the kernel mapping in transformers.

@shawntan
Contributor

shawntan commented Oct 7, 2025

Alright! What do you need specifically? I can come up with some benchmarks for the Granite models.

@MekkCyber
Contributor

Yes, that would be great! We need some latency and memory benchmarks, with and without the kernel, for different sequence lengths, to see whether we get speedups or improved memory consumption.
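
A minimal harness for that kind of measurement might look like the sketch below; it assumes the model and device setup from the snippet earlier in the thread and uses standard torch.cuda timing and memory counters. It is forward-only; a training benchmark would also time backward().

import torch

def benchmark(model, seq_lens=(512, 1024, 2048, 4096), batch_size=4, hidden_size=1536, iters=10):
    results = {}
    for seq_len in seq_lens:
        x = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
        torch.cuda.reset_peak_memory_stats()
        for _ in range(3):  # warmup
            model(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            model(x)
        end.record()
        torch.cuda.synchronize()
        results[seq_len] = {
            "latency_ms": start.elapsed_time(end) / iters,
            "peak_mem_gb": torch.cuda.max_memory_allocated() / 1e9,
        }
    return results

# Run once on the eager model and once after kernelize() to compare the two paths.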

@shawntan
Contributor

shawntan commented Oct 7, 2025

The ScatterMoE report (https://arxiv.org/pdf/2403.08245) has some benchmarks on previous models, compared against MegaBlocks.

It includes both speedup and memory usage comparisons.

[Figure: speedup and memory usage benchmarks from the ScatterMoE report]

@MekkCyber
Contributor

Very nice performance, @shawntan! Thanks for sharing! So the kernel is used for training, not inference?

@shawntan
Contributor

shawntan commented Oct 8, 2025

It can be used for both, but the gains will mainly come in training, or in prefill.

I've made the draft changes for transformers: shawntan@3297b5e

@shawntan
Contributor

shawntan commented Oct 8, 2025

Started another PR here, #41458, because it is significantly different from the current one.

@MekkCyber
Contributor

Let’s close this PR then; it will be superseded by #41458.

@mayank31398 closed this on Oct 9, 2025
@mayank31398
Contributor Author

ciao
