
Commit 7938e91

Authored by ArthurZucker, LysandreJik, vasqu, and Cyrilvallez
MoE + vllm = 😻 (#40132)
* update modeling mixtral * oups[13;2u * fix * better naming? * compute softmax and top_k inside the experts * update minamax as well * models that will need an update * more models that need a fix * stash * fix mixtral * update olmoe * update * update * current changes * nits * molmoe is now fixed * olmoe is good to go! * refactor qwen2_moe * fixes * fixed moe * fix qwen2 modular * nit * qwen2_moie test script works * tricky rope ! * fix qwen3 * DeepSeek v3 MoE Standardization (#40538) * DeepSeek-v3 Shared Shared * Dependents of DS3 * Standardize GLM4V MoE (#40539) * up * Standardize VitPose's MoE (#40549) * VitPose * outside * outside * outside * fix * update dbrx * dbrx... the magix * Refactor Ernie 4.5's MoE (#40547) * Isolate Ernie fixes * fix moe --------- Co-authored-by: Vasqu <antonprogamer@gmail.com> * fix style * style * fix copies * style * latest changes * fixes * had to stage * current updaters * up * another modular * modular graniteMoe * some update * draft another modular moe * updaters * up * fix nit * q3 nit * fix phi moe * we're going up up up up its our mooooment * fix switch transformers this time around * up * gptsan japanese is deprecated forget about it * fix mixtral to not be a linear (gives us more freedom) * update * fix copies gone wrong try catch nothing * fix mixtral * new refactor again * update aria as well * up dbrx and deepseekv3 * nit * fix phimoe? * fix deepseek v3 * nits * don't bother with this one please * up olmoe * ?? * fix olmoe * yups * fiupx * ish * hot patch * new qwen3 * updates * up * nit * fix copies * fix * nits * we're going up up up * nits * switch_transformesr edge case * lol modular gptsan? * fix deepseek * finally all modeling match modular * update * up * up * dang * up * up aria * fix dbrx * nits here and there * finish fixing dbrx * fix deepseek * upd * up * fix flex olmo * updated * update jamba * JAMBA is stil a bit todo * forward forward * fix dots11 * update * fix hunyuan * fix some other * update phimoe * fuck you phimoe you are now submitted * submit granitemoe as well * try to fix some other models, reduces some of the failures * fix olmoe and qwem2moe * up * up * fix qwen2_moe * update modular make it again, simpler * nits * up * up * fix * someswitch reductions * up * fix qwen3vl * some fixes to jetmo * these should be shipped to the modular to fix jetmoe * fix most of the nllb failures * more nllb fixes * fix the modular * remove nllb modular as it sucks for now * ? * fix granitemoe * granitemoehybrid don't have rope * use rope when rope, no rope when no rope * updates * finish fixing dumbgrainite * fix most of minimax * fix * update modular * ? * up * up jetmoe still broken * up * fix, now align the moe * fix jetmoe * fix styling and qwen3 repo consitency * updatge * up up * update ruff? * nits * modeling is goot now for switch * fix * more fixses to switch! * fix some siwtch test * ? * ? * up * fix switch modular! * nit? * uip * subtest * can't believe I wasted so much time on this... * fix * updates * nits * nit jamba is fucking annoying * ? * fix? * oups * good good * styling * up * make sure qwen2 sliding works! * fix dbrx small * lol * nits * fix one test * fix load balancing loss issue * fix jamba * fix nllbmoe * fix jamba consistency and doc? 
* up * thse are correct * up * up * up * some of the final cleanup * update * up * fix some revert in granimoe * bring back attention multipliers for the granite family we'll see later on if they need removal * small jamba fix docstring and typing * fix phimoe * yup * fix unk returndict in granitemoes * up * fix qwen config * fix phiemoe check quality * nits * update based on caught non relative imports! * fix dbrx * Apply suggestions from code review Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> * fix copies * fiuxp * fix dot1 regression! * fix phimoe issue * fix phi moe * fix float() for some models * fix jamba regression * ui * more dtype issues * fix deepseek2 and 3? * proper update * fix modular deepseek! * jamba jambaaaaaa --------- Co-authored-by: Lysandre Debut <hi@lysand.re> Co-authored-by: Vasqu <antonprogamer@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
1 parent e6a8e7d commit 7938e91

File tree

86 files changed: +8433 additions, −9185 deletions


docs/source/en/model_doc/switch_transformers.md

Lines changed: 0 additions & 1 deletion
@@ -105,7 +105,6 @@ print(tokenizer.decode(outputs[0]))
 ## SwitchTransformersTop1Router
 
 [[autodoc]] SwitchTransformersTop1Router
-    - _compute_router_probabilities
     - forward
 
 ## SwitchTransformersSparseMLP

src/transformers/modeling_outputs.py

Lines changed: 2 additions & 0 deletions
@@ -357,6 +357,7 @@ class MoEModelOutput(ModelOutput):
     hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[tuple[torch.FloatTensor, ...]] = None
     router_probs: Optional[tuple[torch.FloatTensor]] = None
+    router_logits: Optional[tuple[torch.FloatTensor]] = None
 
 
 @dataclass
@@ -494,6 +495,7 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
     attentions: Optional[tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
     router_probs: Optional[tuple[torch.FloatTensor]] = None
+    router_logits: Optional[tuple[torch.FloatTensor]] = None
 
 
 @dataclass
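For readers skimming the diff: the new `router_logits` field simply sits next to the existing `router_probs` tuple, so callers can pick up the raw router scores per layer. A minimal, hypothetical sketch (shapes and values are invented for illustration, not taken from this commit):

import torch
from transformers.modeling_outputs import MoEModelOutput

num_tokens, num_experts = 4, 8
logits = torch.randn(num_tokens, num_experts)         # raw router scores for one MoE layer

output = MoEModelOutput(
    last_hidden_state=torch.randn(1, num_tokens, 16),
    router_logits=(logits,),                          # new field: unnormalized logits, one entry per layer
    router_probs=(logits.softmax(dim=-1),),           # existing field, kept alongside it
)
print(output.router_logits[0].shape)                  # torch.Size([4, 8])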

src/transformers/models/aria/modeling_aria.py

Lines changed: 31 additions & 83 deletions
@@ -309,117 +309,65 @@ def forward(self, input, tokens_per_expert):
         )
 
 
-class AriaGroupedExpertsMLP(nn.Module):
-    """
-    Grouped MLP module for Mixture of Experts.
-
-    Args:
-        config (`AriaTextConfig`):
-            Configuration object for the model.
-    """
-
+class AriaExperts(nn.Module):
     def __init__(self, config: AriaTextConfig) -> None:
         super().__init__()
         self.config = config
         self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
         self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
 
-    def forward(self, permuted_tokens, tokens_per_expert):
-        """
-        Forward pass of the Grouped MLP.
-
-        Args:
-            permuted_tokens (torch.Tensor): Permuted input tokens.
-            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
-
-        Returns:
-            torch.Tensor: Output tensor after passing through the MLP.
-        """
-        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
-        projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
-        fc2_output = self.fc2(fc1_output, tokens_per_expert)
-        return fc2_output
-
-
-# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
-class AriaTextMoELayer(nn.Module):
-    """
-    Aria Text Mixture of Experts (MoE) Layer.
-
-    This layer applies a gating mechanism to route input tokens to different experts.
-
-    Args:
-        config (`AriaTextConfig`):
-            Configuration object for the text component of the model.
-    """
-
-    def __init__(self, config: AriaTextConfig):
-        super().__init__()
-
-        self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
-        self.experts = AriaGroupedExpertsMLP(config)
-        self.shared_experts = AriaSharedExpertsMLP(config)
-        self.config = config
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the MoE Layer.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                Input tensor of shape (batch_size, sequence_length, hidden_size).
-
-        Returns:
-            torch.Tensor: Output tensor after passing through the MoE layer.
-
-        Process:
-        1. Route tokens to experts using the router.
-        2. Permute tokens based on routing decisions.
-        3. Process tokens through experts.
-        4. Unpermute and combine expert outputs.
-        5. Add shared expert output to the final result.
-        """
-        original_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
-
-        # Top K Routing
-        logits = self.router(hidden_states)
-        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
+    def route_tokens_to_experts(self, router_logits):
+        top_logits, top_indices = torch.topk(router_logits, k=self.config.moe_topk, dim=1)
         scores = nn.functional.softmax(top_logits, dim=-1)
+        return top_indices, scores
 
-        original_dtype = top_indices.dtype
-
+    def forward(self, hidden_states, router_logits) -> torch.Tensor:
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
+        original_dtype = top_k_index.dtype
         tokens_per_expert = torch.histc(
-            top_indices.flatten().to(torch.float32),
+            top_k_index.flatten().to(torch.float32),
             bins=self.config.moe_num_experts,
             min=0,
            max=self.config.moe_num_experts - 1,
         ).to(original_dtype)
-        indices = top_indices
+        indices = top_k_index
 
-        # Token permutation
         flatten_indices = indices.view(-1)
         sorted_indices = torch.argsort(flatten_indices)
         permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
 
-        # Process through experts
-        expert_output = self.experts(permuted_tokens, tokens_per_expert)
+        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
+        projection, gate = torch.chunk(fc1_output, 2, dim=-1)
+        fc1_output = nn.functional.silu(projection) * gate
+        expert_output = self.fc2(fc1_output, tokens_per_expert)
 
-        # Token unpermutation
         unpermuted_tokens = torch.zeros(
-            (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
+            (top_k_weights.shape[0] * self.config.moe_topk, expert_output.size(1)),
             dtype=expert_output.dtype,
             device=expert_output.device,
         )
         unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
         unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
 
-        output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
+        output = (unpermuted_tokens * top_k_weights.unsqueeze(-1)).sum(dim=1)
+        return output
 
-        # Add shared expert output
+
+class AriaTextMoELayer(nn.Module):
+    def __init__(self, config: AriaTextConfig):
+        super().__init__()
+        self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
+        self.experts = AriaExperts(config)
+        self.shared_experts = AriaSharedExpertsMLP(config)
+        self.config = config
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        original_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+        router_logits = self.router(hidden_states)
+        expert_output = self.experts(hidden_states, router_logits).view(original_shape)
         shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
-        return output + shared_expert_output
+        return expert_output + shared_expert_output
 
 
 def rotate_half(x):
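The key change above: the MoE layer now only produces raw router logits, while top-k selection and normalization live inside the experts module (route_tokens_to_experts). A standalone sketch of that routing step, with invented tensor sizes:

import torch
import torch.nn.functional as F

def route_tokens_to_experts(router_logits: torch.Tensor, top_k: int):
    # Keep the k highest-scoring experts per token, then softmax only over those k logits,
    # mirroring AriaExperts.route_tokens_to_experts in the diff above.
    top_logits, top_indices = torch.topk(router_logits, k=top_k, dim=1)
    scores = F.softmax(top_logits, dim=-1)
    return top_indices, scores

router_logits = torch.randn(5, 8)       # 5 tokens, 8 experts (illustrative)
indices, weights = route_tokens_to_experts(router_logits, top_k=2)
print(indices.shape, weights.shape)     # torch.Size([5, 2]) torch.Size([5, 2])
print(weights.sum(dim=-1))              # each row sums to 1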

src/transformers/models/aria/modular_aria.py

Lines changed: 31 additions & 83 deletions
@@ -1120,117 +1120,65 @@ def forward(self, input, tokens_per_expert):
         )
 
 
-class AriaGroupedExpertsMLP(nn.Module):
-    """
-    Grouped MLP module for Mixture of Experts.
-
-    Args:
-        config (`AriaTextConfig`):
-            Configuration object for the model.
-    """
-
+class AriaExperts(nn.Module):
     def __init__(self, config: AriaTextConfig) -> None:
         super().__init__()
         self.config = config
         self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
         self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
 
-    def forward(self, permuted_tokens, tokens_per_expert):
-        """
-        Forward pass of the Grouped MLP.
-
-        Args:
-            permuted_tokens (torch.Tensor): Permuted input tokens.
-            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
-
-        Returns:
-            torch.Tensor: Output tensor after passing through the MLP.
-        """
-        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
-        projection, gate = torch.chunk(fc1_output, 2, dim=-1)
-        fc1_output = nn.functional.silu(projection) * gate
-        fc2_output = self.fc2(fc1_output, tokens_per_expert)
-        return fc2_output
-
-
-# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
-class AriaTextMoELayer(nn.Module):
-    """
-    Aria Text Mixture of Experts (MoE) Layer.
-
-    This layer applies a gating mechanism to route input tokens to different experts.
-
-    Args:
-        config (`AriaTextConfig`):
-            Configuration object for the text component of the model.
-    """
-
-    def __init__(self, config: AriaTextConfig):
-        super().__init__()
-
-        self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
-        self.experts = AriaGroupedExpertsMLP(config)
-        self.shared_experts = AriaSharedExpertsMLP(config)
-        self.config = config
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the MoE Layer.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                Input tensor of shape (batch_size, sequence_length, hidden_size).
-
-        Returns:
-            torch.Tensor: Output tensor after passing through the MoE layer.
-
-        Process:
-        1. Route tokens to experts using the router.
-        2. Permute tokens based on routing decisions.
-        3. Process tokens through experts.
-        4. Unpermute and combine expert outputs.
-        5. Add shared expert output to the final result.
-        """
-        original_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
-
-        # Top K Routing
-        logits = self.router(hidden_states)
-        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
+    def route_tokens_to_experts(self, router_logits):
+        top_logits, top_indices = torch.topk(router_logits, k=self.config.moe_topk, dim=1)
         scores = nn.functional.softmax(top_logits, dim=-1)
+        return top_indices, scores
 
-        original_dtype = top_indices.dtype
-
+    def forward(self, hidden_states, router_logits) -> torch.Tensor:
+        top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
+        original_dtype = top_k_index.dtype
         tokens_per_expert = torch.histc(
-            top_indices.flatten().to(torch.float32),
+            top_k_index.flatten().to(torch.float32),
            bins=self.config.moe_num_experts,
            min=0,
            max=self.config.moe_num_experts - 1,
         ).to(original_dtype)
-        indices = top_indices
+        indices = top_k_index
 
-        # Token permutation
         flatten_indices = indices.view(-1)
         sorted_indices = torch.argsort(flatten_indices)
         permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
 
-        # Process through experts
-        expert_output = self.experts(permuted_tokens, tokens_per_expert)
+        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
+        projection, gate = torch.chunk(fc1_output, 2, dim=-1)
+        fc1_output = nn.functional.silu(projection) * gate
+        expert_output = self.fc2(fc1_output, tokens_per_expert)
 
-        # Token unpermutation
         unpermuted_tokens = torch.zeros(
-            (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
+            (top_k_weights.shape[0] * self.config.moe_topk, expert_output.size(1)),
             dtype=expert_output.dtype,
             device=expert_output.device,
         )
         unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
         unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
 
-        output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
+        output = (unpermuted_tokens * top_k_weights.unsqueeze(-1)).sum(dim=1)
+        return output
 
-        # Add shared expert output
+
+class AriaTextMoELayer(nn.Module):
+    def __init__(self, config: AriaTextConfig):
+        super().__init__()
+        self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
+        self.experts = AriaExperts(config)
+        self.shared_experts = AriaSharedExpertsMLP(config)
+        self.config = config
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        original_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+        router_logits = self.router(hidden_states)
+        expert_output = self.experts(hidden_states, router_logits).view(original_shape)
         shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
-        return output + shared_expert_output
+        return expert_output + shared_expert_output
 
 
 class AriaTextAttention(LlamaAttention):
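modular_aria.py mirrors the modeling file, so the same dispatch logic applies. For completeness, here is a self-contained sketch of just the permute/unpermute bookkeeping that AriaExperts now owns, with the grouped-GEMM expert replaced by an identity and all shapes invented for illustration:

import torch

def dispatch_and_combine(hidden_states, top_k_index, top_k_weights, num_experts, top_k):
    # Count how many token copies each expert receives (fed to the grouped GEMM in the real model).
    tokens_per_expert = torch.histc(
        top_k_index.flatten().float(), bins=num_experts, min=0, max=num_experts - 1
    ).to(top_k_index.dtype)
    # Sort token copies by assigned expert so each expert sees a contiguous block.
    sorted_indices = torch.argsort(top_k_index.view(-1))
    permuted_tokens = hidden_states.index_select(0, sorted_indices // top_k)

    expert_output = permuted_tokens  # stand-in for the grouped expert MLP

    # Scatter expert outputs back to their original slots and mix them with the routing weights.
    unpermuted = torch.zeros_like(permuted_tokens)
    unpermuted.index_copy_(0, sorted_indices, expert_output)
    unpermuted = unpermuted.view(-1, top_k, hidden_states.size(-1))
    return (unpermuted * top_k_weights.unsqueeze(-1)).sum(dim=1), tokens_per_expert

hidden = torch.randn(5, 16)             # 5 tokens, hidden size 16
idx = torch.randint(0, 8, (5, 2))       # top-2 expert ids out of 8 experts
w = torch.full((5, 2), 0.5)             # equal mixing weights
out, counts = dispatch_and_combine(hidden, idx, w, num_experts=8, top_k=2)
print(out.shape, counts)                # torch.Size([5, 16]) and per-expert token counts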

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 18 additions & 10 deletions
@@ -159,16 +159,24 @@ def update(
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-            device = self.conv_states[layer_idx].device
-            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.ssm_states[layer_idx].device
-            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
+        if self.get_seq_length() > 0:
+            for layer_idx in range(len(self.key_cache)):
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+                device = self.conv_states[layer_idx].device
+                self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
+                device = self.ssm_states[layer_idx].device
+                self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
+
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """Return the length and offset of the cache, used to generate the mask"""
+        kv_offset = 0
+        query_length = cache_position.shape[0]
+        kv_length = self.get_seq_length(layer_idx) + query_length
+        return kv_length, kv_offset
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
