Refactor weight loading #41580
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LysandreJik left a comment
Impressive effort
…to the correct device, etc.
[For maintainers] Suggested jobs to run (before merge): run-slow: aimv2, albert, align
Very cool! 🔥 After pulling the latest changes from main and trying to load gpt-oss-20b, I get this error:

```
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
Unrecognized keys in `rope_parameters` for 'rope_type'='yarn': {'truncate'}
Unrecognized keys in `rope_parameters` for 'rope_type'='yarn': {'truncate'}
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████| 459/459 [00:00<00:00, 574.18it/s, Materializing param=lm_head.weight]
GptOssForCausalLM LOAD REPORT from: openai/gpt-oss-20b
Key                                                   | Status     |
------------------------------------------------------+------------+-
model.layers.{0...23}.mlp.experts.down_proj_blocks    | UNEXPECTED |
model.layers.{0...23}.mlp.experts.gate_up_proj_scales | UNEXPECTED |
model.layers.{0...23}.mlp.experts.gate_up_proj_blocks | UNEXPECTED |
model.layers.{0...23}.mlp.experts.down_proj_scales    | UNEXPECTED |
model.layers.{0...23}.mlp.experts.gate_up_proj        | MISSING    |
model.layers.{0...23}.mlp.experts.down_proj           | MISSING    |
Notes:
- UNEXPECTED: can be ignored when loading from a different task/architecture; not OK if you expect an identical architecture.
- MISSING: those params were newly initialized because they are missing from the checkpoint. Consider training on your downstream task.
```

Just flagging this as it seems to break backwards compatibility. I can also confirm that checking out the second-to-last commit (i.e., without this change) does not result in the error.
It won't break; @MekkCyber and @SunMarc are working on MXFP4 support!
```python
for _ in range(config.num_experts):
    self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))
```
This change is not straightforward and breaks downstream libraries that expect the Qwen2MoeExperts experts to be nn.Linear. Is there an easy workaround?
```python
for _ in range(self.num_experts):
    self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size))
```
same comment
Just for my understanding: is this expected to land in 4.58?

CORE REFACTORING: loading, converting, logging
More helpful debugging report when loading weights

If you just want to fuse qkv:

It can. You just need to make sure you change the model code and poof!
For DeepSeek we will embed the rope permute.
WeightConverter API: the API allows you to define a mapping using WeightConverter. You can define many-to-one source/target keys, quantization operations and distributed operations along with normal operations. For now there are MergeModulelist and Concatenate; we will add the RopePermute one soon. We used to have this:
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4545-L4568
But now it's just explicit (see the sketch below), and it's faster because we don't iterate over the whole checkpoint.
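For example, an explicit per-model mapping that merges a ModuleList of per-expert linear layers into a single stacked parameter might look roughly like this. It is a hedged sketch reusing the MergeModulelist and Concatenate ops named above; the import path and keyword names are assumptions:

```python
# Hedged sketch: collect per-expert gate/up/down projections into fused 3D
# parameters. Import path and signatures are assumptions, not the final API.
from transformers.core_model_loading import WeightConverter, MergeModulelist, Concatenate

experts_mapping = [
    WeightConverter(
        source_keys=[
            "mlp.experts.*.gate_proj.weight",
            "mlp.experts.*.up_proj.weight",
        ],
        target_keys="mlp.experts.gate_up_proj",
        operations=[MergeModulelist(dim=0), Concatenate(dim=1)],  # stack experts, then fuse gate|up
    ),
    WeightConverter(
        source_keys="mlp.experts.*.down_proj.weight",
        target_keys="mlp.experts.down_proj",
        operations=[MergeModulelist(dim=0)],
    ),
]
```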
The core logic is:

1. Iterate over all of the dict keys: match keys against a source pattern like (mlp.experts.*.gate_proj.weight|mlp.experts.*.up_proj.weight) and collect them into a dict keyed by target key. This produces the per-target groups illustrated in the sketch after this list. We need to keep track of which layers were collected, and from which source pattern.

1bis. Schedule tensor materialization, without blocking the GIL (as this takes the most amount of time). We distribute the tensor at this stage, before any operations. This is the trickiest part. We do it during collection to not waste time.

2. Run the operations on all the collected values (at this point { "mlp.experts.0.w1.weight": [t0, t1, t2, ...], "mlp.experts.1.w1.weight": [t0, t1, t2, ...] }.values() gives a list of lists).

3. Build a dict with the target_key and the output values. We pass this to the quantizer.

Keys are handled a lot better!
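To make the collection step concrete, here is a small, self-contained sketch in plain Python (not the actual transformers internals) of grouping source keys under their target key; the pattern and key names are illustrative:

```python
# Hedged illustration of step 1: group checkpoint keys by their many-to-one
# target key. This mimics the idea, not the real implementation.
import re
from collections import defaultdict

checkpoint_keys = [
    "model.layers.0.mlp.experts.0.gate_proj.weight",
    "model.layers.0.mlp.experts.0.up_proj.weight",
    "model.layers.0.mlp.experts.1.gate_proj.weight",
    "model.layers.0.mlp.experts.1.up_proj.weight",
]

# one source pattern -> one target key (illustrative names)
pattern = re.compile(r"(.*)\.experts\.(\d+)\.(gate_proj|up_proj)\.weight")
collected = defaultdict(list)
for key in checkpoint_keys:
    match = pattern.match(key)
    if match:
        target_key = f"{match.group(1)}.experts.gate_up_proj"
        collected[target_key].append(key)  # the real code collects the (lazy) tensors

print(dict(collected))
# {'model.layers.0.mlp.experts.gate_up_proj': ['model.layers.0.mlp.experts.0.gate_proj.weight', ...]}
```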
Enable MoE quantization for FP8
This script does not work on main
Enable TP + MoE without OOM
This script does not work on main
Enable device_map="auto" + MoE + FP8
This script does not work on main
Refactor the way we load weights: faster, more flexible, and better overall
Uses staging buffers per conversion op
device_map="auto"MoEquantization with FP8TODOS:
Script:
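The timing numbers below come from the author's benchmark; a minimal sketch of a load-then-generate script that produces output of this shape, with the checkpoint and generation settings as assumptions, is:

```python
# Hedged sketch of a loading benchmark; checkpoint name and generation settings
# are assumptions, not the exact script used for the numbers below.
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "mistralai/Mixtral-8x7B-v0.1"  # assumed MoE checkpoint

start = time.time()
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto")
print(f"loading took {time.time() - start}")

tokenizer = AutoTokenizer.from_pretrained(ckpt)
inputs = tokenizer("hey how are you?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(output))
```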
loading took 14.271092891693115 ['<s> hey how are you?\n\nI am a 20 year old male and I have been having']
⬆️ is with: merge modulelist, concat gate_up
⬇️ is naive loading.
loading took 54.271092891693115 ['<s> hey how are you?\n\nI am a 20 year old male and I have been having']