
Commit 2208a35

Adding ScatterMoE.
1 parent 8ac2b91 commit 2208a35

5 files changed: +20 −0 lines changed

src/transformers/dependency_versions_table.py (1 addition, 0 deletions)

@@ -46,6 +46,7 @@
     "opencv-python": "opencv-python",
     "optimum-benchmark": "optimum-benchmark>=0.3.0",
     "optuna": "optuna",
+    "optax": "optax>=0.08,<=0.1.4",
     "pandas": "pandas<2.3.0",
     "packaging": "packaging>=20.0",
     "parameterized": "parameterized>=0.9",

src/transformers/integrations/hub_kernels.py (10 additions, 0 deletions)

@@ -84,6 +84,16 @@
             )
         },
     },
+    "ScatterMoEGatedMLP": {
+        "cuda": {
+            Mode.TRAINING: LayerRepository(
+                repo_id="kernels-community/scattermoe", layer_name="ScatterMoEGatedMLP"
+            ),
+            Mode.INFERENCE: LayerRepository(
+                repo_id="kernels-community/scattermoe", layer_name="ScatterMoEGatedMLP"
+            ),
+        },
+    },
     "FastGELU": {
         "cuda": {
             Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
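
For readers unfamiliar with the kernels integration, here is a minimal sketch of how an entry like the one above is consumed: a layer class is tagged with use_kernel_forward_from_hub under the same name, and kernelize() later swaps its forward for the matching Hub kernel on the current device and mode, falling back to the eager forward otherwise. The ToyGatedMLP module and its sizes are illustrative assumptions, not code from this commit, and the real ScatterMoEGatedMLP kernel expects the attributes of the Granite MoE layers.

# Sketch only: illustrates the decorator/kernelize wiring, not the real MoE layer.
import torch
from torch import nn
from kernels import Mode, kernelize, use_kernel_forward_from_hub


@use_kernel_forward_from_hub("ScatterMoEGatedMLP")  # same name as the mapping key
class ToyGatedMLP(nn.Module):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Eager fallback, used whenever no kernel matches the device/mode.
        return self.proj(hidden_states)


model = nn.Sequential(ToyGatedMLP())
# kernelize() walks the module tree and, for tagged layers, binds the Hub
# kernel's forward when a compatible entry exists for the model's device.
model = kernelize(model, mode=Mode.INFERENCE)
print(model(torch.randn(2, 64)).shape)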

src/transformers/models/granitemoe/modeling_granitemoe.py (3 additions, 0 deletions)

@@ -22,6 +22,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -129,6 +130,7 @@ def load_balancing_loss_func(


 # Copied from transformers.models.granite.modeling_granite.GraniteRMSNorm with Granite->GraniteMoe
+@use_kernel_forward_from_hub("RMSNorm")
 class GraniteMoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """

@@ -317,6 +319,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
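
Taken together with the mapping in hub_kernels.py, the decorators above let a GraniteMoE checkpoint pick up the Hub kernels at runtime. A hedged usage sketch follows; the checkpoint id and generation settings are assumptions, and on devices with no matching kernel kernelize() keeps the eager forwards.

# Usage sketch (assumes a CUDA machine, the `kernels` package, and that the
# checkpoint id below is a valid GraniteMoE model on the Hub).
import torch
from kernels import Mode, kernelize
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ibm-granite/granite-3.0-1b-a400m-instruct"  # illustrative GraniteMoE checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# Swap the decorated layers (RMSNorm, ScatterMoEGatedMLP) for their Hub kernels.
model = kernelize(model, mode=Mode.INFERENCE)

inputs = tokenizer("ScatterMoE speeds up sparse MoE layers by", return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))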

src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py (3 additions, 0 deletions)

@@ -29,6 +29,7 @@

 from ...cache_utils import Cache
 from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -934,6 +935,7 @@ class GraniteFlashAttentionKwargs(TypedDict, total=False):
     seq_idx: torch.IntTensor


+@use_kernel_forward_from_hub("RMSNorm")
 class GraniteMoeHybridRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """

@@ -1047,6 +1049,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeHybridMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

src/transformers/models/granitemoeshared/modeling_granitemoeshared.py (3 additions, 0 deletions)

@@ -28,6 +28,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast

@@ -99,6 +100,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


+@use_kernel_forward_from_hub("RMSNorm")
 class GraniteMoeSharedRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """

@@ -212,6 +214,7 @@ def forward(self, hidden_states):
         return index_sorted_experts, batch_index, batch_gates, expert_size, logits


+@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
 class GraniteMoeSharedMoE(nn.Module):
     """
     A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
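
Finally, the same decorator names can be re-pointed from user code, which is useful when experimenting with a fork of the kernel. A hedged sketch follows, assuming the kernels package's register_kernel_mapping helper accepts the same Mode-keyed mapping shape used above; the repository id is a placeholder, not an existing Hub repo.

# Sketch: extending/overriding the kernel mapping from user code.
from kernels import LayerRepository, Mode, register_kernel_mapping

register_kernel_mapping(
    {
        "ScatterMoEGatedMLP": {
            "cuda": {
                Mode.TRAINING: LayerRepository(
                    repo_id="my-org/my-scattermoe", layer_name="ScatterMoEGatedMLP"
                ),
                Mode.INFERENCE: LayerRepository(
                    repo_id="my-org/my-scattermoe", layer_name="ScatterMoEGatedMLP"
                ),
            },
        },
    }
)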
