Merged
387 commits
4d79709
ah actually we don't discard lm head if missing -> needs to be moved …
ArthurZucker Nov 3, 2025
d1e84db
fix some tests
ArthurZucker Nov 3, 2025
f2938df
small fixes
ArthurZucker Nov 3, 2025
22fcdaf
up
ArthurZucker Nov 3, 2025
7d78aa1
up
ArthurZucker Nov 3, 2025
80517f5
dik why we tie weights twice but,..,,.
ArthurZucker Nov 3, 2025
2ff8532
ups
ArthurZucker Nov 3, 2025
d923061
removeunused
ArthurZucker Nov 3, 2025
ce8c1c1
fix hunyuan
ArthurZucker Nov 3, 2025
23e3ed7
small fix
ArthurZucker Nov 3, 2025
a8fb554
nits
ArthurZucker Nov 3, 2025
ab6ee8a
ish
ArthurZucker Nov 3, 2025
77ccbb1
up
ArthurZucker Nov 3, 2025
8a8beff
rev
ArthurZucker Nov 3, 2025
02386ce
fix more tie weights keys
ArthurZucker Nov 3, 2025
1c87945
small fixes
ArthurZucker Nov 3, 2025
00b95ee
nit
ArthurZucker Nov 3, 2025
a170f29
update
ArthurZucker Nov 3, 2025
8b924a3
fix and fix
ArthurZucker Nov 3, 2025
8f7b1d0
fix a test
ArthurZucker Nov 3, 2025
9386217
glubs
ArthurZucker Nov 3, 2025
4894a25
current shitty changes
ArthurZucker Nov 3, 2025
da7dc10
ship validated ones
ArthurZucker Nov 4, 2025
d7c8171
more
ArthurZucker Nov 4, 2025
e088408
more update
ArthurZucker Nov 4, 2025
4f212de
more
ArthurZucker Nov 4, 2025
dc5a22c
more
ArthurZucker Nov 4, 2025
675b2bc
more
ArthurZucker Nov 4, 2025
f85f239
mllama
ArthurZucker Nov 4, 2025
76b6a92
more up
ArthurZucker Nov 4, 2025
ba1a8b6
fix ernie
ArthurZucker Nov 4, 2025
ba3de5a
fix xopies
ArthurZucker Nov 4, 2025
8fd255c
up more
ArthurZucker Nov 4, 2025
5d7507b
more fixes
ArthurZucker Nov 4, 2025
0fb2340
up
ArthurZucker Nov 4, 2025
32b9273
up
ArthurZucker Nov 4, 2025
0b95826
fix-copies
ArthurZucker Nov 4, 2025
5794d27
fix more
ArthurZucker Nov 4, 2025
5e71bd4
more updates
ArthurZucker Nov 4, 2025
20d1b34
AI UPDATE
ArthurZucker Nov 4, 2025
89846e7
up
ArthurZucker Nov 5, 2025
a581fd7
hoey
ArthurZucker Nov 5, 2025
1652c9c
make it fast
Cyrilvallez Nov 5, 2025
dcad703
fix
Cyrilvallez Nov 5, 2025
c921ced
lol
ArthurZucker Nov 5, 2025
50714d8
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
8936cc4
fix asjusting
ArthurZucker Nov 5, 2025
5c54332
more fixes
ArthurZucker Nov 5, 2025
ff10878
_dtype nit
ArthurZucker Nov 5, 2025
9601b82
up
ArthurZucker Nov 5, 2025
db02b9d
nit
ArthurZucker Nov 5, 2025
42fd4c4
update
ArthurZucker Nov 5, 2025
4527171
update
ArthurZucker Nov 5, 2025
bd36211
remove semaphores
Cyrilvallez Nov 5, 2025
e2aefee
fix import to avoid jit execution
Cyrilvallez Nov 5, 2025
74a0e9c
try to remove custom tiing logic when its stupid
ArthurZucker Nov 5, 2025
ead2ac3
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
e7165da
fix more individual models
ArthurZucker Nov 5, 2025
2ff765e
fix whisper as well
ArthurZucker Nov 5, 2025
912562c
fix?
ArthurZucker Nov 5, 2025
c43495a
fox umt5
ArthurZucker Nov 5, 2025
57988f2
improve tqdm bar
Cyrilvallez Nov 5, 2025
8c16de1
cleanup a bit
Cyrilvallez Nov 5, 2025
b8927d6
oupsi
Cyrilvallez Nov 5, 2025
2733ff6
some updates
ArthurZucker Nov 5, 2025
8baa3fe
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
d91701f
improve
Cyrilvallez Nov 5, 2025
5146dec
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
Cyrilvallez Nov 5, 2025
acc5b24
remove all buffering -> much faster without it
Cyrilvallez Nov 5, 2025
58389a1
remove some tie_weights custome funcs when not needed
ArthurZucker Nov 5, 2025
92c0229
more fixes related to strict matching regex
ArthurZucker Nov 5, 2025
d9e7fe6
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
b57d789
remove ALL custom tie weights
ArthurZucker Nov 5, 2025
ef8b6c3
small update
ArthurZucker Nov 5, 2025
a228fd0
revert change to init scheme (no need for params)
Cyrilvallez Nov 5, 2025
07574dd
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
2526cc5
mixtral init
Cyrilvallez Nov 5, 2025
6cb3794
try less strict source check
ArthurZucker Nov 5, 2025
e4cadfb
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
3fea865
tied weight first shot to the fiiiixxxxxx
Cyrilvallez Nov 5, 2025
82f94b8
does this help?
ArthurZucker Nov 5, 2025
84dd6eb
:)
ArthurZucker Nov 5, 2025
cc08195
fix some ppolry defined tied_weights_keys for now
ArthurZucker Nov 5, 2025
f692f4b
subclass nn.Parameters
ArthurZucker Nov 7, 2025
2fa058f
up
ArthurZucker Nov 7, 2025
78d4622
lol
ArthurZucker Nov 7, 2025
8ff4ad5
Ouiiii
ArthurZucker Nov 7, 2025
3222678
fix led
ArthurZucker Nov 7, 2025
9a76a6e
fix long cat flash
ArthurZucker Nov 7, 2025
9fde9f7
fix qwen and long cat flash
ArthurZucker Nov 7, 2025
074a449
properly fix qwen init
ArthurZucker Nov 7, 2025
dde5500
just push this for now
ArthurZucker Nov 7, 2025
0e7d2d0
propnet is dumb
ArthurZucker Nov 7, 2025
18b02ee
update
ArthurZucker Nov 7, 2025
9c0db72
push
ArthurZucker Nov 7, 2025
75d3afc
remove explict sharing of some tied keys.
ArthurZucker Nov 7, 2025
85ab085
update decoder.bias
ArthurZucker Nov 7, 2025
443573a
moe case
ArthurZucker Nov 7, 2025
f8f0973
more changes to untangle old hardcoded ting
ArthurZucker Nov 7, 2025
5c9d56c
fixup
ArthurZucker Nov 7, 2025
a0029f2
Merge branch 'main' into refactor-weight-loading
ArthurZucker Nov 7, 2025
44943fb
fix big faileurs
ArthurZucker Nov 7, 2025
76d66be
fix prophnet
ArthurZucker Nov 7, 2025
d176b48
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 7, 2025
3ffc59e
fix resize token embeddings
ArthurZucker Nov 10, 2025
2a00e49
nits
ArthurZucker Nov 10, 2025
f7d0183
fix xcodex
ArthurZucker Nov 10, 2025
bbf5b00
asyncio?
ArthurZucker Nov 10, 2025
0412832
fix smart apply
ArthurZucker Nov 10, 2025
c137ea3
fix data-2-vec
ArthurZucker Nov 10, 2025
7b7c990
[build-ci-image]
ArthurZucker Nov 10, 2025
de74aeb
checkout
ArthurZucker Nov 10, 2025
94a53d4
uupdate
ArthurZucker Nov 10, 2025
8755a4b
fix hunyuan
ArthurZucker Nov 10, 2025
5be67b9
update error message
ArthurZucker Nov 10, 2025
86a4e51
fix deformable detr
ArthurZucker Nov 10, 2025
09bcd2e
fixes
ArthurZucker Nov 10, 2025
7b457fd
fix init weights for non param gate up projs
ArthurZucker Nov 10, 2025
e033947
shared todo?
ArthurZucker Nov 10, 2025
f93f357
update some models
ArthurZucker Nov 10, 2025
2f0a6ae
big revert, don't break this behaviour
ArthurZucker Nov 10, 2025
3c8c757
ty @SunMarc this fixes the buffers
ArthurZucker Nov 10, 2025
f5a7c33
mt5 fuck
ArthurZucker Nov 10, 2025
647f720
fix lxmbert
ArthurZucker Nov 10, 2025
bed6ea1
nuke slow test fetcher
ArthurZucker Nov 10, 2025
2ec0a5f
fix zamba and deepcopy for now
ArthurZucker Nov 10, 2025
f9c7ef8
fix zamba tied weight keys! ~
ArthurZucker Nov 10, 2025
8df3ffd
fix-copies
ArthurZucker Nov 10, 2025
e76481b
update fetch terst
ArthurZucker Nov 10, 2025
de00751
fix gradient for test modeling common!
ArthurZucker Nov 10, 2025
cdd1a9b
break "shared" for now I will fix tomorrow changes are properly isoal…
ArthurZucker Nov 10, 2025
d3f6476
does this fix marian? probably not
ArthurZucker Nov 10, 2025
0a7db83
fix some vlms
ArthurZucker Nov 10, 2025
1814200
D fine seems to handle this well
ArthurZucker Nov 10, 2025
b77825d
glob is fine actually
ArthurZucker Nov 11, 2025
5dbb783
fix dab detr
ArthurZucker Nov 11, 2025
9edc81b
small steps
ArthurZucker Nov 11, 2025
970f4e5
opusy
ArthurZucker Nov 11, 2025
0361d47
fix some more models?
ArthurZucker Nov 11, 2025
dc75773
yups
ArthurZucker Nov 11, 2025
cdb1284
better erro
ArthurZucker Nov 11, 2025
de9a2d9
fix?
ArthurZucker Nov 11, 2025
b9a9f4d
fix double escape
ArthurZucker Nov 11, 2025
c944619
escape wehere it makes sense
ArthurZucker Nov 11, 2025
f910524
??
ArthurZucker Nov 11, 2025
4aa2ade
fix ibert
ArthurZucker Nov 11, 2025
2ef1c2b
fix tvp as well
ArthurZucker Nov 11, 2025
b98a7bc
more fxes
ArthurZucker Nov 11, 2025
74e6c87
try always download ref PR
ArthurZucker Nov 11, 2025
5064edd
ONONONO
ArthurZucker Nov 11, 2025
3f8a304
big fixup
ArthurZucker Nov 11, 2025
3ecaa63
more fixup
ArthurZucker Nov 11, 2025
f384524
small step
ArthurZucker Nov 11, 2025
290337a
small nits
ArthurZucker Nov 11, 2025
76b388c
nits
ArthurZucker Nov 11, 2025
e69b988
brut force some stuff
ArthurZucker Nov 11, 2025
c2781f5
fix vilt
ArthurZucker Nov 11, 2025
f64ee96
make sure special models that always need tie always tie
ArthurZucker Nov 11, 2025
a3e4015
cleaning up
ArthurZucker Nov 11, 2025
9eecbd2
small nits
ArthurZucker Nov 11, 2025
b2fa432
fix zamba and bridge tower!
ArthurZucker Nov 11, 2025
dbbfdf2
just fixup
ArthurZucker Nov 11, 2025
ab4890c
potential culprits
ArthurZucker Nov 11, 2025
937ebf3
revert bark and fix bridgetower
ArthurZucker Nov 11, 2025
e4f9697
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Nov 11, 2025
17803ce
remove now non existant tie_weights
ArthurZucker Nov 11, 2025
9f6838a
?
ArthurZucker Nov 11, 2025
1afb3eb
lol reformer actually had nothing tied!
ArthurZucker Nov 11, 2025
f01a149
wow these two fucking models were really not well made
ArthurZucker Nov 11, 2025
0b36980
fix sam family!
ArthurZucker Nov 11, 2025
d740c82
fix bark revision
ArthurZucker Nov 11, 2025
6f3940e
fix speech2test ?
ArthurZucker Nov 11, 2025
b2f6f61
push this for now....
ArthurZucker Nov 11, 2025
ade8dab
upsy
ArthurZucker Nov 11, 2025
f956ccf
the fuck
ArthurZucker Nov 11, 2025
99c6fd4
fix rtdetr
ArthurZucker Nov 11, 2025
1ffcfc3
update
ArthurZucker Nov 11, 2025
ee62aec
proper
ArthurZucker Nov 11, 2025
6ec80f8
wow that one 's annoying
ArthurZucker Nov 11, 2025
b05e329
update
ArthurZucker Nov 11, 2025
2606596
try to find the culprit
ArthurZucker Nov 11, 2025
d9e8a09
get some help on common
ArthurZucker Nov 12, 2025
581665a
nit about general init and cls.padding_idx
ArthurZucker Nov 12, 2025
c43bc68
revert num workers update
ArthurZucker Nov 12, 2025
b6fe415
remove old loading func
Cyrilvallez Nov 12, 2025
4bb8e5c
fix glob
ArthurZucker Nov 12, 2025
7d52b06
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
455bcc7
add annotations
Cyrilvallez Nov 12, 2025
fc884c0
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
Cyrilvallez Nov 12, 2025
2e0ed5d
fix re
ArthurZucker Nov 12, 2025
3ddd1cc
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
1f86a10
small improvements
Cyrilvallez Nov 12, 2025
4d56fbf
fix conflict
Cyrilvallez Nov 12, 2025
67a8eeb
clean some stuff
Cyrilvallez Nov 12, 2025
e9168ff
improvements
Cyrilvallez Nov 12, 2025
feda22d
someone did not understannnnnnd what I tried to dooo or does BNB not …
ArthurZucker Nov 12, 2025
70841c9
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
52248ba
gluos
ArthurZucker Nov 12, 2025
e8dd4a4
fix case when `.` is just not there
ArthurZucker Nov 12, 2025
1c67fc4
remove unused arg
Cyrilvallez Nov 12, 2025
e20ed00
recover orignal parameter/buffer using _original
SunMarc Nov 12, 2025
827c42a
fix glob issu
ArthurZucker Nov 12, 2025
e5e4d28
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
4db2aa6
this?
ArthurZucker Nov 12, 2025
2b16c17
deepspeed best-effort
Cyrilvallez Nov 12, 2025
c411ddb
remove unused stuff
Cyrilvallez Nov 12, 2025
56d368b
Update tie weight keys as they were just wroong
ArthurZucker Nov 12, 2025
85d0ac1
up
ArthurZucker Nov 12, 2025
daa642c
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
bbf71b9
augustuc clauss, a gloubs gloups gloubs
ArthurZucker Nov 12, 2025
127e4d5
fixup
ArthurZucker Nov 12, 2025
7954185
fixup
ArthurZucker Nov 12, 2025
f7cd4b3
there was fucking typo
ArthurZucker Nov 12, 2025
f9e747e
mrain
ArthurZucker Nov 12, 2025
57bf5b2
nits
ArthurZucker Nov 12, 2025
c38ad24
fix marian 3 remaining tests
ArthurZucker Nov 12, 2025
d7be7df
one more
ArthurZucker Nov 12, 2025
729e3df
fix some of the copies, not all :)
ArthurZucker Nov 12, 2025
c95a3f1
small cleanup
ArthurZucker Nov 12, 2025
8778840
one propertest
ArthurZucker Nov 13, 2025
1181e3f
fix core model loadig tes
ArthurZucker Nov 13, 2025
b750e6b
attempt a new test
ArthurZucker Nov 13, 2025
3178c3f
fix some of the annoying tests by supporting reading .bin sometimes
ArthurZucker Nov 13, 2025
d6ab250
push
ArthurZucker Nov 13, 2025
0695197
push more small fixes
ArthurZucker Nov 13, 2025
fd5a75a
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Nov 13, 2025
f54b528
remove 1 useless test
ArthurZucker Nov 13, 2025
1abf6a9
up
ArthurZucker Nov 13, 2025
3014290
fix audio flamingo post rebase
ArthurZucker Nov 13, 2025
1f1bea3
fixup
ArthurZucker Nov 13, 2025
c2dbca0
some small updatess
ArthurZucker Nov 13, 2025
347b966
fix sam models
ArthurZucker Nov 13, 2025
40ed636
nits
ArthurZucker Nov 13, 2025
3b2f934
up
ArthurZucker Nov 13, 2025
fb0fb89
updates
ArthurZucker Nov 13, 2025
92e2771
onem ore
ArthurZucker Nov 13, 2025
06f2ba9
skip this stupid test
ArthurZucker Nov 13, 2025
3d5c86c
some other fixes
ArthurZucker Nov 13, 2025
15bc48e
fixup
ArthurZucker Nov 13, 2025
47743f8
update
ArthurZucker Nov 13, 2025
d77cf57
skip more offloaded stuff
ArthurZucker Nov 13, 2025
75f2bd4
oups
ArthurZucker Nov 13, 2025
08ad69b
ups
ArthurZucker Nov 13, 2025
b605e1a
update mixtral
ArthurZucker Nov 13, 2025
91d40b8
skip this one
ArthurZucker Nov 13, 2025
638bbfc
LET"SGO
ArthurZucker Nov 13, 2025
7daacb4
fixup
ArthurZucker Nov 13, 2025
22c19a7
rope delta order
ArthurZucker Nov 13, 2025
6d89354
fix csm
ArthurZucker Nov 13, 2025
9ccb693
small nit
ArthurZucker Nov 13, 2025
17 changes: 17 additions & 0 deletions src/transformers/conversion_mapping.py
@@ -0,0 +1,17 @@
# File storing the default checkpoint conversion mapping used in `transformers`.
#
# Either we keep it here or we move it to the config, but for newcomers,
# seeing this here is a bit weird, no?

from .core_model_loading import ConversionType, WeightConversion

_checkpoint_conversion_mapping = {
    "mixtral": {
        "experts.*.(w1|w3).weight$": WeightConversion(
            "experts.gate_up_proj.weight", [ConversionType.MERGE_MODULE_LIST, ConversionType.FUSE]
        ),
        "self_attn.(q|k|v)_proj": WeightConversion("self_attn.qkv_proj", ConversionType.FUSE),
        "experts.*.w2.weight": WeightConversion("experts.down_proj.weight", ConversionType.MERGE_MODULE_LIST),
    }
}
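For intuition, a small sketch (made-up keys, not part of the diff) of how the gate/up entry selects checkpoint keys. The mapping's first pattern is presumably meant to target `w1`/`w3` (Mixtral's gate and up projections, with `w2` being the down projection handled separately); the escaped variant below makes the intended dots explicit, though unescaped `.` would also match since it matches any character:

```python
import re

# Hypothetical Mixtral-style checkpoint keys (illustration only)
keys = [
    "model.layers.0.block_sparse_moe.experts.1.w1.weight",
    "model.layers.0.block_sparse_moe.experts.1.w3.weight",
    "model.layers.0.self_attn.q_proj.weight",
]

# Escaped form of the gate/up pattern: expert index, then w1 or w3, then .weight at the end
pattern = r"experts\.\d+\.(w1|w3)\.weight$"
matched = [k for k in keys if re.search(pattern, k)]
# matched keeps only the two expert gate/up weights
```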

244 changes: 244 additions & 0 deletions src/transformers/core_model_loading.py
@@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core helpers for loading model checkpoints."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Iterable, Optional, Union

import torch


"""
For mixtral, the fp8 quantizer should add the "quantization" op.

Quantizer says wether we need all weights or not.

TP probably does not need?


model.layers.0.block_sparse_moe.experts.1.w1.input_scale []
model.layers.0.block_sparse_moe.experts.1.w1.weight [14 336, 4 096]
model.layers.0.block_sparse_moe.experts.1.w1.weight_scale []
model.layers.0.block_sparse_moe.experts.1.w2.input_scale []
model.layers.0.block_sparse_moe.experts.1.w2.weight [4 096, 14 336]
model.layers.0.block_sparse_moe.experts.1.w2.weight_scale []
model.layers.0.block_sparse_moe.experts.1.w3.input_scale []
model.layers.0.block_sparse_moe.experts.1.w3.weight [14 336, 4 096]
model.layers.0.block_sparse_moe.experts.1.w3.weight_scale []
"""


class ConversionOps:
"""
Base class with a reusable buffer to avoid repeated allocations.
Subclasses implement `convert(collected_tensors) -> torch.Tensor` and
write results into a view of `self._buffer`.
"""

    target_tensor_shape: Optional[torch.Size] = None
can_be_quantized: bool = True
can_be_distributed: bool = False

# Lazily created on first use; no __init__ needed.
_buffer: Optional[torch.Tensor] = None

def _ensure_buffer(
self, required_shape: torch.Size, *, dtype: torch.dtype, device: torch.device, growth_factor: float = 1.5
) -> torch.Tensor:
"""
Ensure we have a buffer with enough capacity (and correct dtype/device).
Returns a *view* of the buffer shaped as `required_shape` without new allocation.
"""
        required_elems = torch.Size(required_shape).numel()  # product of dims; 1 for a 0-d shape

need_new = (
self._buffer is None
or self._buffer.dtype != dtype
or self._buffer.device != device
or self._buffer.numel() < required_elems
)

if need_new:
# grow capacity to reduce future reallocations
capacity = max(required_elems, int(required_elems * growth_factor))
self._buffer = torch.empty(capacity, dtype=dtype, device=device)

# return a view with the requested shape using only the needed slice
return self._buffer[:required_elems].view(required_shape)

def clear_cache(self):
"""Free the cached buffer (optional)."""
self._buffer = None

def convert(self, collected_tensors: Iterable[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError
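The reuse policy in `_ensure_buffer` boils down to a capacity check plus an over-allocation factor. A dependency-free sketch (the helper name `plan_capacity` is made up for illustration):

```python
# Sketch of _ensure_buffer's reuse policy: hand out the existing buffer when it
# is large enough, otherwise grow capacity by growth_factor to amortize future
# reallocations. Returns (capacity, whether a new allocation happened).
def plan_capacity(current, required, growth_factor=1.5):
    if current is not None and current >= required:
        return current, False                          # reuse: no allocation
    return max(required, int(required * growth_factor)), True

cap1, alloc1 = plan_capacity(None, 100)   # first request: allocate 150
cap2, alloc2 = plan_capacity(cap1, 120)   # 120 <= 150: reuse
cap3, alloc3 = plan_capacity(cap2, 200)   # outgrown: reallocate 300
```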


class Fuse(ConversionOps):
"""
Concatenate along `dim` without allocating a fresh output each call:
copies into a preallocated buffer slice-by-slice.
"""

dim: int = 0 # adjust if you want a different default

def convert(self, collected_tensors: Iterable[torch.Tensor]) -> torch.Tensor:
tensors = tuple(collected_tensors)
if not tensors:
# Return a zero-size view on an empty buffer on CPU by default
self._buffer = None
return torch.empty(0)

# Basic checks & canonical attrs
first = tensors[0]
dtype, device = first.dtype, first.device
dim = self.dim

# Validate shapes/dtypes/devices
base_shape = list(first.shape)
for t in tensors:
if t.dtype != dtype or t.device != device:
raise TypeError("All tensors must share dtype and device for Fuse.")
if len(t.shape) != len(base_shape):
raise ValueError("All tensors must have the same rank for Fuse.")
for d, (a, b) in enumerate(zip(base_shape, t.shape)):
if d == dim:
continue
if a != b:
raise ValueError(f"Non-concat dims must match; got {a} vs {b} at dim {d}.")

# Compute fused shape
total_along_dim = sum(t.shape[dim] for t in tensors)
out_shape = list(base_shape)
out_shape[dim] = total_along_dim
out_shape = torch.Size(out_shape)

with torch.no_grad():
out = self._ensure_buffer(out_shape, dtype=dtype, device=device)

# Copy into preallocated buffer without creating a new result tensor
# We slice along `dim` and copy each piece.
idx = 0
for t in tensors:
slc = [slice(None)] * t.ndim
slc[dim] = slice(idx, idx + t.shape[dim])
out[tuple(slc)].copy_(t)
idx += t.shape[dim]

return out
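The copy loop above can be pictured without torch at all: each input lands in its own slice of a preallocated output, with a running offset along the concat dimension. A pure-Python sketch:

```python
# Dependency-free sketch of Fuse's slice-by-slice copy along dim 0:
# each input is written into its own slice of a preallocated output.
tensors = [[1, 2], [3, 4, 5]]            # stand-ins for 1-D tensors
total = sum(len(t) for t in tensors)
out = [0] * total                        # the preallocated "buffer" view
offset = 0
for t in tensors:
    out[offset:offset + len(t)] = t      # the copy_ step
    offset += len(t)
# out is the concatenation of the inputs
```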


class MergeModuleList(ConversionOps):
"""
Stack tensors along a new leading dimension without allocating a new tensor:
writes each tensor into a preallocated [N, ...] buffer.
"""

stack_dim: int = 0 # new dimension index in the *output*

def convert(self, collected_tensors: Iterable[torch.Tensor]) -> torch.Tensor:
tensors = tuple(collected_tensors)
if not tensors:
self._buffer = None
return torch.empty(0)

first = tensors[0]
dtype, device = first.dtype, first.device
base_shape = first.shape

# Validate consistency
for t in tensors:
if t.dtype != dtype or t.device != device:
raise TypeError("All tensors must share dtype and device for MergeModuleList.")
if t.shape != base_shape:
raise ValueError("All tensors must have identical shapes to stack.")

N = len(tensors)
# Normalize stack_dim (allow negative)
stack_dim = self.stack_dim % (first.ndim + 1)

# Output shape: insert N at stack_dim
out_shape = list(base_shape)
out_shape.insert(stack_dim, N)
out_shape = torch.Size(out_shape)

with torch.no_grad():
out = self._ensure_buffer(out_shape, dtype=dtype, device=device)

# Write each tensor into the appropriate slice
for i, t in enumerate(tensors):
slc = [slice(None)] * out.ndim
slc[stack_dim] = i
out[tuple(slc)].copy_(t)

return out
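One subtlety above: the output has rank `ndim + 1`, so a negative `stack_dim` is normalized modulo `ndim + 1`, not `ndim`. A quick sketch of that arithmetic:

```python
# Sketch of MergeModuleList's stack_dim normalization: the stacked output has
# one extra axis, so negative dims wrap around (ndim + 1) positions.
ndim = 2                             # rank of each input tensor
normalized_last = -1 % (ndim + 1)    # "insert as the new last axis"
normalized_first = 0 % (ndim + 1)    # the default: new leading axis
```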

class Fp8Quantize(ConversionOps):
    def convert(self, collected_tensors):
        # TODO: the quantizer API needs more context (model, param name, target
        # device) than a bare tensor collection; this is a placeholder call.
        from .quantizers.quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer

        return FineGrainedFP8HfQuantizer.create_quantized_param(collected_tensors)


class Slice(ConversionOps):
    # TODO: implement slicing for TP
    def convert(self, inputs):
        return inputs

class ConversionType(Enum):
    FUSE = Fuse()
    MERGE_MODULE_LIST = MergeModuleList()
    FP8_QUANTIZE = Fp8Quantize()
    SLICE = Slice()

    def __call__(self, *args, **kwargs):
        # Delegate calls on an enum member to its op: ConversionType.FUSE(tensors)
        return self.value.convert(*args, **kwargs)


globals().update({member.name: member for member in ConversionType})


class WeightConversion:
    """
    Specification for renaming a checkpoint key and applying conversion operations.

    Most probably takes the tp_plan and the quantization_config here, and calls all the different ops.
    """

    new_key_name: str
    operations: list[ConversionType]  # if TP or quantization, some ops like "slicing" will be added?

    def __init__(self, new_key_name: str, operations: Optional[Union[ConversionType, list[ConversionType]]] = None):
        self.new_key_name = new_key_name
        if operations is None:
            self.operations = []
        elif isinstance(operations, list):
            self.operations = operations
        else:
            self.operations = [operations]

# Ex rank1 for w1,w3 -> gate_up_proj:
# 1. read the weights
# 2. rename
# 3. MergeModuleList, but dim=0, and there is tp_plan on gate_up_proj -> slice to only experts of this rank
# 4. cat(cat(gate_4, gate_5, gate_6, gate_7), cat(up_4, up_5, up_6, up_7))
# 5. quantize? -> A new ConversionType op
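Steps 3-4 of the comment above can be sketched concretely (all names and the sharding scheme are assumptions for illustration): keep only this rank's experts, then concatenate the gate block followed by the up block.

```python
# Hedged sketch of rank-local expert fusion for w1/w3 -> gate_up_proj.
num_experts, world_size, rank = 8, 2, 1
per_rank = num_experts // world_size
local = range(rank * per_rank, (rank + 1) * per_rank)   # experts 4..7 on rank 1
gate = {e: f"w1[{e}]" for e in range(num_experts)}      # stand-ins for weights
up = {e: f"w3[{e}]" for e in range(num_experts)}
# step 3: slice to this rank's experts; step 4: cat(gate block, up block)
gate_up_proj = [gate[e] for e in local] + [up[e] for e in local]
```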

# We want the quantizers to have:
# -


__all__ = ["WeightConversion", "ConversionType"]
1 change: 1 addition & 0 deletions src/transformers/integrations/finegrained_fp8.py
@@ -353,6 +353,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return output.to(dtype=input.dtype)


# TODO: we do need this....
def _replace_with_fp8_linear(
model,
tp_plan=None,