
Commit 4fd05b1: ScatterMoE

1 parent 0419ff8

File tree: 5 files changed, +11 -0 lines changed

src/transformers/integrations/hub_kernels.py (6 additions, 0 deletions)

@@ -82,6 +82,12 @@
             )
         },
     },
+    "ScatterMoEGatedMLP": {
+        "cuda": {
+            Mode.TRAINING: LayerRepository(repo_id="shawntan/scattermoe", layer_name="ScatterMoEGatedMLP"),
+            Mode.INFERENCE: LayerRepository(repo_id="shawntan/scattermoe", layer_name="ScatterMoEGatedMLP"),
+        },
+    },
     "FastGELU": {
         "cuda": {
             Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
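With this mapping entry in place, any layer class decorated with the name ScatterMoEGatedMLP can have its forward swapped for the ScatterMoE kernel hosted at shawntan/scattermoe, on CUDA, in both training and inference modes. A minimal sketch of activating it, assuming the Hugging Face kernels library is installed and a CUDA device is available; the checkpoint id below is an illustrative assumption, not part of this commit:

import torch
from kernels import Mode, kernelize
from transformers import AutoModelForCausalLM

# Illustrative GraniteMoE checkpoint; any model whose MoE block carries the
# @use_kernel_forward_from_hub("ScatterMoEGatedMLP") decorator qualifies.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.0-3b-a800m-base",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

# kernelize() walks the module tree and, for each decorated layer, looks up
# the registered repo (here shawntan/scattermoe) and swaps in the hub
# kernel's forward; by default it falls back to the original Python forward
# when no matching kernel is available for the device/mode.
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")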

src/transformers/models/granitemoe/modeling_granitemoe.py (1 addition, 0 deletions)

@@ -192,6 +192,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
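For reference, the decorator itself only tags the class with a kernel name; the class's own forward remains the reference implementation. A toy sketch of the pattern, assuming the kernels library (MyGatedMLP and its shapes are hypothetical, not from this commit):

import torch
import torch.nn as nn
from kernels import use_kernel_forward_from_hub


# The string must match a key in the kernel mapping ("ScatterMoEGatedMLP"
# above); kernelize() later uses it to locate the replacement forward.
@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
class MyGatedMLP(nn.Module):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Reference implementation; used whenever the hub kernel has not
        # been kernelized in or is unavailable for the current device/mode.
        return self.proj(hidden_states)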

src/transformers/models/granitemoe/modular_granitemoe.py (2 additions, 0 deletions)

@@ -20,6 +20,7 @@

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -49,6 +50,7 @@ class GraniteMoeTopKGating(JetMoeTopKGating):
     pass


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py (1 addition, 0 deletions)

@@ -1000,6 +1000,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeHybridMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

src/transformers/models/granitemoeshared/modeling_granitemoeshared.py (1 addition, 0 deletions)

@@ -207,6 +207,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeSharedMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
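Taken together, the same one-line decorator is applied to GraniteMoeMoE (in both the modeling and modular files), GraniteMoeHybridMoE, and GraniteMoeSharedMoE, so the single ScatterMoEGatedMLP entry in hub_kernels.py covers all three Granite MoE variants.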
