@ArthurZucker ArthurZucker commented Oct 14, 2025

CORE REFACTORING, loading, converting, logging

More helpful debugging report when loading weights
[screenshot of the load report]

If you just want to fuse qkv:
[screenshot of the qkv fusion example]

It can. You just need to make sure you change the model code, and poof!

            WeightConverter(
                ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                "self_attn.qkv_proj",
                operations=[Concatenate(dim=0)],  # more like stack?
            ),
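
For intuition, here is roughly what Concatenate(dim=0) amounts to for this mapping, written in plain torch with assumed, illustrative shapes (a sketch, not the library implementation):

import torch

# Illustrative shapes only (assumed, not taken from a specific model)
hidden_size, kv_size = 4096, 1024

q = torch.randn(hidden_size, hidden_size)  # self_attn.q_proj.weight
k = torch.randn(kv_size, hidden_size)      # self_attn.k_proj.weight (GQA: fewer kv heads)
v = torch.randn(kv_size, hidden_size)      # self_attn.v_proj.weight

# Concatenate(dim=0): fuse the three projections along the output dimension
qkv = torch.cat([q, k, v], dim=0)          # self_attn.qkv_proj.weight
assert qkv.shape == (hidden_size + 2 * kv_size, hidden_size)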

For DeepSeek we will embed the RoPE permute:

            WeightConverter(
                ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                operations=[RopePermute()],
            ),
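
RopePermute itself is not shown in this PR description. As a hedge, the sketch below uses the permutation familiar from Llama-style conversion scripts (interleaved rotary pairs reordered into the half-split layout the Hugging Face attention code expects), which is the kind of transform such an operation would embed; the actual implementation may differ:

import torch

def rope_permute(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Reorder rotary dimensions from interleaved pairs to the half-split layout
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# Assumed shapes: 32 heads, hidden size 4096
q_proj = torch.randn(4096, 4096)
q_proj_permuted = rope_permute(q_proj, n_heads=32, dim1=4096, dim2=4096)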

WeightConverter API:

The API allows you to define a mapping using WeightConverter. You can define many-to-one source/target keys, quantization operations, and distributed operations along with normal operations. For now MergeModulelist and Concatenate; the RopePermute one will be added soon.

_checkpoint_conversion_mapping = {
    "mixtral": [
        WeightConverter(
            source_keys=[
                "mlp.experts.*.w1.weight",
                "mlp.experts.*.w3.weight",
            ],
            target_keys="mlp.experts.gate_up_proj",
            operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
        ),
        WeightConverter(
            source_keys=["mlp.experts.*.w2.weight"],
            target_keys="mlp.experts.down_proj",
            operations=[MergeModulelist(dim=0)],
        ),
    ],
}
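
Interpreting the listed operations literally, the first converter amounts to something like the following (a sketch with assumed Mixtral-8x7B shapes, not the library code):

import torch

num_experts, d_ff, d_model = 8, 14336, 4096
w1 = [torch.randn(d_ff, d_model) for _ in range(num_experts)]  # mlp.experts.*.w1.weight (gate)
w3 = [torch.randn(d_ff, d_model) for _ in range(num_experts)]  # mlp.experts.*.w3.weight (up)

# MergeModulelist(dim=0): stack the per-expert weights into one 3D tensor
merged_w1 = torch.stack(w1, dim=0)  # [num_experts, d_ff, d_model]
merged_w3 = torch.stack(w3, dim=0)  # [num_experts, d_ff, d_model]

# Concatenate(dim=1): fuse gate (w1) and up (w3) into a single gate_up_proj
gate_up_proj = torch.cat([merged_w1, merged_w3], dim=1)  # [num_experts, 2 * d_ff, d_model]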

We used to have this:

https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4545-L4568

But now it's just explicit:

        "legacy": [
            WeightConverter(
                source_keys="LayerNorm.gamma",
                target_keys="LayerNorm.weight",
            ),
            WeightConverter(
                source_keys="LayerNorm.beta",
                target_keys="LayerNorm.bias",
            ),
        ],
    }
    if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
        mapping["legacy"] += [
            WeightConverter(
                source_keys="weight_g",
                target_keys="parametrizations.weight.original0",
            ),
            WeightConverter(
                source_keys="weight_v",
                target_keys="parametrizations.weight.original1",
            ),
        ]
    else:
        mapping["legacy"] += [
            WeightConverter(
                source_keys="parametrizations.weight.original0",
                target_keys="weight_g",
            ),
            WeightConverter(
                source_keys="parametrizations.weight.original1",
                target_keys="weight_v",
            ),
        ]

and it's faster because we don't iterate over the whole checkpoint

The core logic is:
Iterate over all of the dict keys:

  1. Collect the keys that match the glob patterns from all source keys; patterns from the same weight converter are piped into one alternation, e.g. (mlp.experts.*.gate_proj.weight|mlp.experts.*.up_proj.weight), and the matches are gathered into a dict keyed by target key (see the sketch after this list).

This produces:

{ 
"mlp.experts.gate_up_proj" : 
    {"mlp.experts.*.w1.weight":
        { "mlp.experts.0.w1.weight": [t0, t1, t2, etc], "mlp.experts.1.w1.weight": [t0, t1, t2, etc]},
     "mlp.experts.*.w3.weight":
        { "mlp.experts.0.w3.weight": [t0, t1, t2, etc], "mlp.experts.1.w3.weight": [t0, t1, t2, etc]},
    }
  ....
}

We need to keep track of which layers were collected, and from which source pattern.

1bis. Schedule tensor materialization without blocking on the GIL, as this takes the most time. We distribute the tensors at this stage, before any operations. This is the trickiest part. We do it during collection so no time is wasted.

  2. We collect the results of materialization, and we apply the operations on all the collected values (at this point { "mlp.experts.0.w1.weight": [t0, t1, t2, etc], "mlp.experts.1.w1.weight": [t0, t1, t2, etc]}.values() gives a list of lists).
  3. We create a dict with the target_key and the output values. We pass this to the quantizer.
  4. We quantize the input tensors, outputting the final dict.
  5. We set the param into the model.
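
A minimal sketch of steps 1 and 1bis, illustrative only and not the transformers implementation; get_tensor is a hypothetical callable that reads one tensor from the checkpoint, and for simplicity each key maps to a single tensor here rather than the per-key list of collected values shown above:

import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def glob_to_regex(pattern: str) -> re.Pattern:
    # "mlp.experts.*.w1.weight" matches e.g. "model.layers.3.mlp.experts.7.w1.weight"
    return re.compile(re.escape(pattern).replace(r"\*", r"[^.]+") + r"$")

def collect_and_materialize(checkpoint_keys, converters, get_tensor, max_workers=8):
    # converters: iterable of (source_patterns, target_key) pairs
    compiled = [
        (target, [(src, glob_to_regex(src)) for src in sources])
        for sources, target in converters
    ]
    collected = defaultdict(lambda: defaultdict(dict))
    futures = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Single pass over the checkpoint keys: bucket by (target_key, source_pattern)
        # and schedule materialization right away so reads overlap with collection
        for key in checkpoint_keys:
            for target, patterns in compiled:
                for src, rx in patterns:
                    if rx.search(key):
                        futures[(target, src, key)] = pool.submit(get_tensor, key)
        for (target, src, key), fut in futures.items():
            collected[target][src][key] = fut.result()
    return collected  # {target_key: {source_pattern: {checkpoint_key: tensor}}}

The remaining steps then run each converter's operations over collected[target_key].values(), hand the result to the quantizer, and set the resulting params on the model.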

Keys are handled a lot better!

Enable MoE quantization for FP8

This script does not work on main

import torch
from transformers import MixtralForCausalLM, AutoTokenizer, FineGrainedFP8Config
import time 
quantization_config = FineGrainedFP8Config(modules_to_not_convert=["model.layers.*.mlp.gate"])
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", quantization_config=quantization_config, tp_plan="auto")

Enable TP + MoE without OOM

This script does not work on main

model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", tp_plan="auto")

Enable device_map="auto" + MoE + FP8

This script does not work on main

quantization_config = FineGrainedFP8Config(modules_to_not_convert=["model.layers.*.mlp.gate"])
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", quantization_config=quantization_config, device_map="auto")

Refactor the way we load weights: faster, more flexible, and better overall

Uses staging buffers per conversion op

  • 4x speedup with device_map="auto"
  • Full MoE quantization with FP8

TODOS:

  • Test with TP / EP
  • Add TQDM!
  • Test with deepspeed
  • Test with loras and peft
  • Test with vllm backend
  • Test with fsdp
  • Add saving

Script:

import torch
from torch import nn
from transformers import MixtralForCausalLM, AutoTokenizer

import time 
start = time.time()
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
end = time.time() 
print("loading took ", end-start)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
inputs = tokenizer("hey how are you?", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(out))
loading took  14.271092891693115
['<s> hey how are you?\n\nI am a 20 year old male and I have been having']

⬆️ is with: merge modulelist, concat gate_up
⬇️ is naive loading.

loading took  54.271092891693115
['<s> hey how are you?\n\nI am a 20 year old male and I have been having']

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@LysandreJik LysandreJik left a comment


Impressive effort

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, albert, align

@ArthurZucker ArthurZucker merged commit 6f6095e into main Nov 13, 2025
21 of 24 checks passed
@ArthurZucker ArthurZucker deleted the refactor-weight-loading branch November 13, 2025 16:12
@xenova

xenova commented Nov 14, 2025

Very cool! 🔥 After pulling latest changes from main and trying to load gpt-oss-20b, I get this error:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
Unrecognized keys in `rope_parameters` for 'rope_type'='yarn': {'truncate'}
Unrecognized keys in `rope_parameters` for 'rope_type'='yarn': {'truncate'}
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████| 459/459 [00:00<00:00, 574.18it/s, Materializing param=lm_head.weight]
GptOssForCausalLM LOAD REPORT from: openai/gpt-oss-20b
Key                                                   | Status     | 
------------------------------------------------------+------------+-
model.layers.{0...23}.mlp.experts.down_proj_blocks    | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj_scales | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj_blocks | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.down_proj_scales    | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj        | MISSING    | 
model.layers.{0...23}.mlp.experts.down_proj           | MISSING    | 

Notes:
- UNEXPECTED    :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING       :those params were newly initialized because missing form the checkpoint. Consider training on your downstream task.

Just flagging as it seems to break backwards compatibility. I can also confirm that checking out the 2nd last commit (i.e., without this change) does not result in the error.

@ArthurZucker

It won't break, @MekkCyber and @SunMarc are working on MXFp4 support!

@MekkCyber

MekkCyber commented Nov 14, 2025

Yes @xenova we are taking care of that here: #42070, we just need to fix some issues and it will be good to go

Comment on lines -300 to -301
for _ in range(config.num_experts):
self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size))

@fxmarty-amd fxmarty-amd Nov 17, 2025


This change is not straightforward and breaks downstream libraries expecting Qwen2MoeExperts experts to be nn.Linear. Is there an easy workaround?

Comment on lines -220 to -221
for _ in range(self.num_experts):
self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size))

same comment

@fxmarty-amd

[screenshot]

🫠

@fxmarty-amd

Just for my understanding - is this expected to land in 4.58?
