Conversation

@shawntan
Contributor

@shawntan shawntan commented Oct 8, 2025

What does this PR do?

Adds ScatterMoE kernel support for Granite MoE models.
Started in #40365, but the approach has deviated significantly, so I'm starting a new pull request.

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@MekkCyber has already provided some comments in #40365.

Contributor

@MekkCyber MekkCyber left a comment

Thanks a lot for this PR @shawntan! Can you open a PR here: https://github.com/huggingface/kernels-community to add your kernel without the build folder, so that we can review the source code? Once it's merged we will upload the builds to the Hub, and then we can merge this PR so everyone can use the kernel. If you don't have time I'm happy to help 🤗

@shawntan
Contributor Author

shawntan commented Oct 9, 2025

So far I've simply been copying from torch-ext to build/torch-universal to "build" the kernel.

Should I follow the example here: https://github.com/huggingface/kernels-community/tree/main/trimul_gpumode

It's another Triton-based kernel with a universal "build".
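For reference, once the kernel is published, a quick smoke test of the Hub artifact could look like the sketch below. This assumes the `kernels` library's get_kernel helper and a hypothetical kernels-community/scattermoe repo id; neither the repo id nor the module layout is final.

from kernels import get_kernel

# Downloads the prebuilt variant for the current torch/platform and imports it
# as a regular Python module (for a pure-Triton kernel this should be the
# torch-universal build).
scattermoe = get_kernel("kernels-community/scattermoe")  # hypothetical repo id
print(dir(scattermoe))  # inspect the exported ops/layers before wiring them into transformers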

@shawntan
Contributor Author

Should I be targeting main or the latest release?

shawntan referenced this pull request Oct 13, 2025
@shawntan
Contributor Author

Some changes have been made to GraniteMoE in the latest main that affect how the kernel is used: e6a8e7d...7938e91#diff-49f545f0c27ad25a565417d120d925f6bfcde532d6e0ea5539dcec182ad978aaL373

I haven't been able to test the current main branch successfully; there seem to be many breaking changes, so I need some help modifying the current PR:

  1. Should I target the main branch or the latest release?
  2. Should I be modifying GraniteMoE to the way it is supposed to be?
  3. Changing GraniteMoeMoE back and regenerating the modeling_ files of the downstream classes would fix the issue, but what other models with MoEs are affected, and how was the change decided?

@shawntan shawntan force-pushed the scattermoe branch 2 times, most recently from 4fd05b1 to 2208a35 on October 15, 2025 21:32
@shawntan
Contributor Author

@MekkCyber, I will need some help here. Which branch should I target for the PR: main or the latest release?

main has issues with the HF API URLs. The current state of the PR targets the latest release.

@MekkCyber
Contributor

MekkCyber commented Oct 16, 2025

Hey @shawntan, let's target main! I think you did a faulty rebase; let's keep only the relevant files in the PR.
I'm not sure what you mean by modifying GraniteMoE, but we shouldn't touch the modeling files; the kernel should be applied solely via the decorator.
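For concreteness, a minimal sketch of the decorator-based approach, assuming the `kernels` library's use_kernel_forward_from_hub and a hypothetical "ScatterMoE" kernel name; this is not the final wiring for this PR:

import torch
import torch.nn as nn
from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("ScatterMoE")  # forward() gets swapped for the Hub kernel at kernelize() time
class GraniteMoeMoE(nn.Module):
    # the reference (non-kernel) implementation stays untouched in the modeling file
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        ...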

@shawntan
Contributor Author

shawntan commented Oct 16, 2025

I've reverted the PR back to the original one against main.

My issue was this line:

main: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py#L235

v4.57.1: https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L320

router_logits has been removed. I cannot figure out why or how it was removed, but I think it breaks the Granite models downstream of that class. It's also hard for me to test this, since pulling from the Hub is also broken in the latest main.

@MekkCyber
Contributor

MekkCyber commented Oct 17, 2025

router_logits has been removed

Yes, it was removed in the MoE refactor for vLLM compatibility with transformers, but this shouldn't break anything in the transformers implementation.

It's also hard for me to test this since pulling from the hub is also broken in the latest main.

What do you mean?

@shawntan
Contributor Author

shawntan commented Oct 17, 2025

router_logits has been removed

Yes, it was removed in the MoE refactor for vLLM compatibility with transformers, but this shouldn't break anything in the transformers implementation.

I see. I will need to change the layer definition in the kernel, since it will only produce one output instead of the tuple it returns right now. I am also confused as to how the auxiliary loss will work if the MoE doesn't return the router_logits. I will look into it.
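For context, this is roughly what the aux loss needs; a self-contained sketch of a Mixtral-style load-balancing loss (tensor names and the exact reduction are illustrative, not the modeling-file helper), which is the consumer of the per-layer router logits that the MoE layer no longer returns:

import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: tuple, num_experts: int, top_k: int) -> torch.Tensor:
    # Encourages a uniform token-to-expert assignment; computed from the
    # per-layer router (gate) logits.
    logits = torch.cat(router_logits, dim=0)                # (layers * tokens, num_experts)
    probs = F.softmax(logits, dim=-1)
    _, selected_experts = torch.topk(probs, top_k, dim=-1)  # token -> expert assignments
    expert_mask = F.one_hot(selected_experts, num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=(0, 1))        # fraction of routing slots per expert
    router_prob_per_expert = probs.mean(dim=0)              # mean gate probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)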

Update:
Looking at it, the current changes break the output_router_logits=True path.

The attribute is checked for during the forward call:

But it is never passed to the model:

outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    cache_position=cache_position,
    **kwargs,
)

There is also no pathway for passing the router logits, or recomputing them in GraniteMoeModel:

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> MoeModelOutputWithPast:
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    if use_cache and past_key_values is None:
        past_key_values = DynamicCache(config=self.config)
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    causal_mask = create_causal_mask(  # ONLY DIFF WITH MIXTRAL: NO SLIDING
        config=self.config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )
    inputs_embeds = inputs_embeds * self.embedding_multiplier
    hidden_states = inputs_embeds
    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        hidden_states = decoder_layer(
            hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
    hidden_states = self.norm(hidden_states)
    return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
    )

What's the ideal way the HF team is thinking of to allow this while still maintaining vLLM compatibility? I can make the necessary changes.

It's also hard for me to test this since pulling from the hub is also broken in the latest main.

What do you mean?

Hmm, sorry, I can't seem to reproduce the same issue I saw before with the Jinja files.

Still, there are some incompatibilities on the Mamba end of things. Right now, with both repositories on main, the error being thrown is this:

  File "/proj/checkpoints/shawntan/transformers/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py", line 47, in <module>
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  File "/u/shawntan/mamba/mamba_ssm/__init__.py", line 6, in <module>
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
  File "/u/shawntan/mamba/mamba_ssm/models/mixer_seq_simple.py", line 20, in <module>
    from mamba_ssm.utils.generation import GenerationMixin
  File "/u/shawntan/mamba/mamba_ssm/utils/generation.py", line 14, in <module>
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
ImportError: cannot import name 'GreedySearchDecoderOnlyOutput' from 'transformers.generation' (/proj/checkpoints/shawntan/transformers/src/transformers/generation/__init__.py)

@shawntan
Contributor Author

shawntan commented Oct 20, 2025

TL;DR:

  1. If I change the kernel to suit the current state of GraniteMoeMoE, it would need one output, but then the calculation of the aux loss would be broken.
  2. I don't believe that is the intended behaviour, so my question is: how should the GraniteMoeMoE class be modified to best support vLLM compatibility without breaking previous functionality (the aux loss)?

@MekkCyber

@shawntan shawntan mentioned this pull request Oct 21, 2025
Collaborator

@ArthurZucker ArthurZucker left a comment

Hey! #41580 was merged. IDK if you need more changes, but indeed, for collecting the router logits you might need a small change! You could put the collection on the gate linear layer?
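A rough sketch of that suggestion, assuming an already-loaded GraniteMoE model and that the per-layer router is a plain nn.Linear; the layer.block_sparse_moe.router attribute path is illustrative, not the actual modeling attribute names:

collected_router_logits = []

def collect_router_logits(module, inputs, output):
    # the gate/router projection output is exactly the per-token router logits
    collected_router_logits.append(output)

handles = [
    layer.block_sparse_moe.router.register_forward_hook(collect_router_logits)  # illustrative path
    for layer in model.model.layers
]
# ...run a forward pass, compute the aux loss from collected_router_logits, then clean up:
for handle in handles:
    handle.remove()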

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: granitemoe, granitemoehybrid, granitemoeshared
