
Commit cdaa40d

ScatterMoE
1 parent 1fd63dd commit cdaa40d

5 files changed: +11 -0 lines changed

src/transformers/integrations/hub_kernels.py

Lines changed: 6 additions & 0 deletions
@@ -115,6 +115,12 @@ def use_kernel_forward_from_hub(layer_name: str):
             )
         },
     },
+    "ScatterMoEGatedMLP": {
+        "cuda": {
+            Mode.TRAINING: LayerRepository(repo_id="shawntan/scattermoe", layer_name="ScatterMoEGatedMLP"),
+            Mode.INFERENCE: LayerRepository(repo_id="shawntan/scattermoe", layer_name="ScatterMoEGatedMLP"),
+        },
+    },
     "FastGELU": {
         "cuda": {
             Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
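The new mapping entry ties the registered layer name to a LayerRepository on the Hub, keyed by device and by mode, so both the training and the inference forward of the decorated MoE blocks can be swapped for the shawntan/scattermoe kernel. The same registration pattern can be reproduced with the standalone kernels package; the sketch below is illustrative only, and the SiluAndMul layer, its repo id, and the toy model are assumptions, not part of this commit:

import torch
from torch import nn
from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


# Registering a layer name makes the class eligible for a Hub-kernel forward.
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]


# Device- and mode-keyed mapping, mirroring the ScatterMoEGatedMLP entry above.
mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation",  # illustrative repo id
                layer_name="SiluAndMul",
            ),
        },
    },
}

model = nn.Sequential(SiluAndMul()).to("cuda")

# kernelize() walks the module tree and, for every class registered with
# use_kernel_forward_from_hub, swaps in the Hub kernel matching the requested
# device and mode; layers without a match keep their native PyTorch forward.
with use_kernel_mapping(mapping):
    model = kernelize(model, mode=Mode.INFERENCE, device="cuda")

out = model(torch.randn(4, 128, device="cuda"))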

src/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 1 addition & 0 deletions
@@ -221,6 +221,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits
 
 
+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
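Decorating GraniteMoeMoE with @use_kernel_forward_from_hub is a no-op on its own; the ScatterMoE forward only replaces the eager implementation once the model is kernelized on a CUDA device. A minimal usage sketch, assuming the standalone kernels package is installed and using an illustrative Granite MoE checkpoint name:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from kernels import Mode, kernelize  # assumes the standalone kernels package

ckpt = "ibm-granite/granite-3.0-3b-a800m-instruct"  # illustrative MoE checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda")

# Swap every GraniteMoeMoE.forward for the ScatterMoEGatedMLP kernel registered
# in hub_kernels.py; layers without a matching mapping keep their eager forward.
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")

inputs = tokenizer("Mixture-of-experts models route tokens to", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))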

src/transformers/models/granitemoe/modular_granitemoe.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -49,6 +50,7 @@ class GraniteMoeTopKGating(JetMoeTopKGating):
     pass
 
 
+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

Lines changed: 1 addition & 0 deletions
@@ -1066,6 +1066,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits
 
 
+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeHybridMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

src/transformers/models/granitemoeshared/modeling_granitemoeshared.py

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits
 
 
+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeSharedMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
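The same decorator is applied to all four GraniteMoe MoE variants, and because the mapping registers a Mode.TRAINING entry alongside Mode.INFERENCE, the swap can also be requested during fine-tuning. A rough sketch, assuming the standalone kernels package, a GraniteMoe-style model already loaded on CUDA, and a hypothetical batch dict:

from kernels import Mode, kernelize  # assumes the standalone kernels package

# `model` is a GraniteMoe / GraniteMoeShared / GraniteMoeHybrid model on CUDA.
model = kernelize(model, mode=Mode.TRAINING, device="cuda")
model.train()

# `batch` is a hypothetical dict with input_ids / attention_mask / labels tensors.
loss = model(**batch).loss
loss.backward()  # gradients flow through the ScatterMoE experts' backward pass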
