
Commit b605e1a

update mixtral

1 parent 08ad69b

File tree

src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mixtral/modular_mixtral.py

2 files changed: +2 -2 lines changed

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 1 addition & 1 deletion

@@ -404,7 +404,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _supports_attention_backend = True
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
+        "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }

src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 1 addition & 1 deletion

@@ -265,7 +265,7 @@ def forward(
 class MixtralPreTrainedModel(MistralPreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _can_record_outputs = {
-        "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
+        "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
         "hidden_states": MixtralDecoderLayer,
         "attentions": MixtralAttention,
     }
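
Both hunks make the same swap: the "router_logits" recorder no longer matches the gating nn.Linear by layer_name="mlp.gate" but targets the dedicated MixtralTopKRouter module directly. As a hedged illustration (not part of the commit), the sketch below shows how the recorded router logits surface through the public Mixtral API; the checkpoint name, dtype, and device placement are illustrative assumptions.

# Minimal sketch, assuming a Mixtral checkpoint is available and fits in memory.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"  # illustrative checkpoint choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs, output_router_logits=True)

# One router-logit tensor per decoder layer, holding each token's scores over the experts.
print(len(out.router_logits), out.router_logits[0].shape)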
