
Commit 7f9f4d9

Fix TP plans for MoE models (#42236)
* start
* more fixes
1 parent 462beff commit 7f9f4d9

16 files changed (+72, -84 lines)
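
The dictionaries touched below are the `base_model_tp_plan` class attributes that transformers reads when a checkpoint is loaded with tensor parallelism: each key is a glob-style module pattern, each value a partitioning style such as "colwise", "rowwise", their "local_*" variants, or "gather". A minimal sketch of how such a plan is consumed, assuming a multi-GPU host, a transformers build that includes this commit, and a `torchrun` launch (the model id is only an example):

```python
# Minimal sketch: loading a MoE checkpoint with its built-in TP plan.
# Assumes a recent transformers release, several visible GPUs, and a launch via
#   torchrun --nproc-per-node 4 load_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "deepseek-ai/DeepSeek-V2-Lite"  # example id only

# tp_plan="auto" picks up config.base_model_tp_plan (the dictionaries edited in
# this commit) and shards the matching weights across ranks.
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallelism splits", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```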

src/transformers/models/deepseek_v2/configuration_deepseek_v2.py

Lines changed: 3 additions & 2 deletions
@@ -127,8 +127,9 @@ class DeepseekV2Config(PreTrainedConfig):
         "layers.*.self_attn.q_b_proj": "colwise",
         "layers.*.self_attn.kv_b_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.gate_up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_colwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/deepseek_v2/modular_deepseek_v2.py

Lines changed: 3 additions & 2 deletions
@@ -142,8 +142,9 @@ class DeepseekV2Config(LlamaConfig):
         "layers.*.self_attn.q_b_proj": "colwise",
         "layers.*.self_attn.kv_b_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.gate_up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_colwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
     }

     model_type = "deepseek_v2"

src/transformers/models/deepseek_v3/configuration_deepseek_v3.py

Lines changed: 10 additions & 13 deletions
@@ -131,19 +131,16 @@ class DeepseekV3Config(PreTrainedConfig):

     model_type = "deepseek_v3"
     keys_to_ignore_at_inference = ["past_key_values"]
-    base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
-        "layers.*.mlp.experts.*.gate_proj": "local_colwise",
-        "layers.*.mlp.experts.*.up_proj": "local_colwise",
-        "layers.*.mlp.experts.*.down_proj": "local_rowwise",
-        "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
-        "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
-        "layers.*.mlp.shared_experts.up_proj": "local_colwise",
-        "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
-        "layers.*.mlp.shared_experts": "local",
-        "layers.*.mlp.gate_proj": "local_colwise",
-        "layers.*.mlp.up_proj": "local_colwise",
-        "layers.*.mlp.down_proj": "local_rowwise",
-        "layers.*.mlp": "gather", # This is the only moment where results are gathered
+    base_model_tp_plan = {
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/dots1/configuration_dots1.py

Lines changed: 10 additions & 13 deletions
@@ -109,23 +109,20 @@ class Dots1Config(PreTrainedConfig):
     model_type = "dots1"
     keys_to_ignore_at_inference = ["past_key_values"]

-    base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
+    base_model_tp_plan = {
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.experts.*.gate_proj": "local_colwise",
-        "layers.*.mlp.experts.*.up_proj": "local_colwise",
-        "layers.*.mlp.experts.*.down_proj": "local_rowwise",
-        "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
-        "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
-        "layers.*.mlp.shared_experts.up_proj": "local_colwise",
-        "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
-        "layers.*.mlp.shared_experts": "local",
-        "layers.*.mlp.gate_proj": "local_colwise",
-        "layers.*.mlp.up_proj": "local_colwise",
-        "layers.*.mlp.down_proj": "local_rowwise",
-        "layers.*.mlp": "gather", # This is the only moment where results are gathered
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
     }

     base_model_pp_plan = {

src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py

Lines changed: 9 additions & 15 deletions
@@ -122,21 +122,15 @@ class Ernie4_5_MoeConfig(PreTrainedConfig):
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        # sequence parallel is pretty slow
-        # "norm.weight": "sequence_parallel",
-        # "layers.*.input_layernorm.weight": "sequence_parallel",
-        # "layers.*.post_attention_layernorm.weight": "sequence_parallel",
-        "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
-        "layers.*.mlp.shared_experts.up_proj": "local_colwise",
-        "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
-        "layers.*.mlp.experts.*.gate_proj": "local_colwise",
-        "layers.*.mlp.experts.*.up_proj": "local_colwise",
-        "layers.*.mlp.experts.*.down_proj": "local_rowwise",
-        "layers.*.mlp.experts": "local",
-        "layers.*.mlp.gate_proj": "local_colwise",
-        "layers.*.mlp.up_proj": "local_colwise",
-        "layers.*.mlp.down_proj": "local_rowwise",
-        "layers.*.mlp": "gather",
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/flex_olmo/configuration_flex_olmo.py

Lines changed: 3 additions & 3 deletions
@@ -115,9 +115,9 @@ class FlexOlmoConfig(PreTrainedConfig):
         "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
         "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
         "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
-        "layers.*.mlp.experts.*.gate_proj": "colwise",
-        "layers.*.mlp.experts.*.up_proj": "colwise",
-        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/flex_olmo/modular_flex_olmo.py

Lines changed: 3 additions & 3 deletions
@@ -125,9 +125,9 @@ class FlexOlmoConfig(OlmoeConfig):
         "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
         "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
         "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
-        "layers.*.mlp.experts.*.gate_proj": "colwise",
-        "layers.*.mlp.experts.*.up_proj": "colwise",
-        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),

src/transformers/models/glm4_moe/configuration_glm4_moe.py

Lines changed: 3 additions & 3 deletions
@@ -121,9 +121,9 @@ class Glm4MoeConfig(PreTrainedConfig):
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.experts.*.gate_proj": "colwise",
-        "layers.*.mlp.experts.*.up_proj": "colwise",
-        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",

src/transformers/models/glm4_moe/modular_glm4_moe.py

Lines changed: 3 additions & 3 deletions
@@ -135,9 +135,9 @@ class Glm4MoeConfig(PreTrainedConfig):
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.mlp.experts.*.gate_proj": "colwise",
-        "layers.*.mlp.experts.*.up_proj": "colwise",
-        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",

src/transformers/models/longcat_flash/configuration_longcat_flash.py

Lines changed: 3 additions & 3 deletions
@@ -129,9 +129,9 @@ class LongcatFlashConfig(PreTrainedConfig):
         "layers.*.mlps.*.gate_proj": "colwise",
         "layers.*.mlps.*.up_proj": "colwise",
         "layers.*.mlps.*.down_proj": "rowwise",
-        "layers.*.mlp.experts.*.gate_proj": "colwise",
-        "layers.*.mlp.experts.*.up_proj": "colwise",
-        "layers.*.mlp.experts.*.down_proj": "rowwise",
+        "layers.*.mlp.experts.gate_up_proj": "local_rowwise",
+        "layers.*.mlp.experts.down_proj": "local_rowwise",
+        "layers.*.mlp.experts": "gather",
     }

     base_model_pp_plan = {
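
Across all of these files the expert weights end up on the "local_*" styles plus a final "gather" on `mlp.experts`: each rank keeps an ordinary local shard, and the cross-rank reduction happens once per MoE block. A distributed-flavored sketch of that single reduction (a stand-in under an initialized `torch.distributed` process group, not the actual transformers hook):

```python
# Stand-in for the single "gather" per MoE block: each rank computes a partial
# block output from its local weight shards and one all-reduce combines them.
# Assumes torchrun has initialized a process group; the shard layout is illustrative.
import torch
import torch.distributed as dist

def moe_block(x: torch.Tensor, local_gate_up: torch.Tensor, local_down: torch.Tensor) -> torch.Tensor:
    # local_gate_up holds this rank's matching gate/up column blocks,
    # local_down the corresponding down_proj row block.
    gate, up = (x @ local_gate_up).chunk(2, dim=-1)
    partial = (torch.nn.functional.silu(gate) * up) @ local_down
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # the one gather point per layer
    return partial
```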
