Commit 8598421

Much more efficient and clear weight initialization and tie weights (#42191)
* everything until informer
* everything until perceiver
* all of them finally
* style
* replace by transformers init everywhere
* use relative import instead
* deprecated models
* style
* start contexts
* small fixes
* fix modular
* remove class switch
* do not initialize tied weights
* typo
* fix
* improve
* improve comments
* improve
* improve
* fix zamba
* fix import
* add the post_init
* more post_init
* fix
* protect
* more post_init
* fix
* fixes
* fix
* fix
* switch flag name
* more fixes
* fixes
* fixes
* copies
* fix
* finally find the culprit
* style
* last small
* big bird
* better
* update init check
* final touch
* do it everywhere
1 parent 16c7afd commit 8598421
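
The diffs below replace direct `torch.nn.init` / in-place tensor calls inside `_init_weights` with guarded wrappers from the new `transformers.initialization` module, which skip any tensor flagged as `_is_hf_initialized`, so weights already loaded from a checkpoint (including tied ones) are not re-initialized. A minimal sketch of that behaviour, assuming `transformers.initialization` is importable as a plain submodule and that the loader sets the flag (here it is set by hand):

import torch
from torch import nn

# Guarded init helpers added by this commit (see src/transformers/initialization.py below).
from transformers import initialization as init

layer = nn.Linear(4, 4)

# Pretend the checkpoint loader already filled the weight and flagged it as loaded.
layer.weight._is_hf_initialized = True
loaded = layer.weight.detach().clone()

init.normal_(layer.weight, mean=0.0, std=0.02)  # skipped: the flag is set
init.zeros_(layer.bias)                         # runs: the bias carries no flag

assert torch.equal(layer.weight, loaded)
assert torch.count_nonzero(layer.bias) == 0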

416 files changed: +3341 additions, -5443 deletions


src/transformers/core_model_loading.py

Lines changed: 0 additions & 116 deletions
@@ -26,7 +26,6 @@
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import partial
-from types import MethodType
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
@@ -313,120 +312,6 @@ class ConversionEntry:
 GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4
 
 
-# Factory function to create LoadedParameter subclasses dynamically
-def get_loaded_parameter_class(base_cls):
-    """
-    base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor
-    Returns a new class that combines the base_cls with LoadedParameterMixin
-
-    """
-
-    class LoadedParam(base_cls):
-        _inplace_methods = [
-            "add_",
-            "mul_",
-            "clamp_",
-            "zero_",
-            "fill_",
-            "normal_",
-            "uniform_",
-            "copy_",
-            "erfinv_",
-            "log_",
-            "__getitem__",
-            "neg_",
-            "exp_",
-            "sub_",
-        ]
-
-        def __new__(cls, from_existing, **kwargs):
-            if isinstance(from_existing, torch.nn.Parameter):
-                inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
-            else:
-                inst = super().__new__(cls, from_existing)
-            # we store the original object to get it back later on
-            inst._original = from_existing
-            # Explicitly override all in-place methods per instance
-            for method_name in inst._inplace_methods:
-                setattr(inst, method_name, MethodType(inst._skip, inst))
-
-            return inst
-
-        def _skip(self, *args, **kwargs):
-            """Helper to skip in-place operations."""
-            return self
-
-        def __repr__(self):
-            return f"LoadedParameter(data={self.data})"
-
-        @property
-        def data(self):
-            return super().data
-
-        @data.setter
-        def data(self, new):
-            pass
-
-        def __lt__(self, other):
-            return torch.Tensor.__lt__(self, other)
-
-        def __le__(self, other):
-            return torch.Tensor.__le__(self, other)
-
-        def __gt__(self, other):
-            return torch.Tensor.__gt__(self, other)
-
-        def __ge__(self, other):
-            return torch.Tensor.__ge__(self, other)
-
-        def __eq__(self, other):
-            return torch.Tensor.__eq__(self, other)
-
-        def __ne__(self, other):
-            return torch.Tensor.__ne__(self, other)
-
-        def __iadd__(self, *args, **kwargs):
-            return self
-
-        def __isub__(self, *args, **kwargs):
-            return self
-
-        def __imul__(self, *args, **kwargs):
-            return self
-
-        def __imatmul__(self, *args, **kwargs):
-            return self
-
-        def __itruediv__(self, *args, **kwargs):
-            return self
-
-        def __ifloordiv__(self, *args, **kwargs):
-            return self
-
-        def __imod__(self, *args, **kwargs):
-            return self
-
-        def __ipow__(self, *args, **kwargs):
-            return self
-
-        def __iand__(self, *args, **kwargs):
-            return self
-
-        def __ior__(self, *args, **kwargs):
-            return self
-
-        def __ixor__(self, *args, **kwargs):
-            return self
-
-        def __ilshift__(self, *args, **kwargs):
-            return self
-
-        def __irshift__(self, *args, **kwargs):
-            return self
-
-    return LoadedParam
-
-
 def _materialize_copy(tensor, dtype=None):
     tensor = tensor[...]
     if dtype is not None:
@@ -527,7 +412,6 @@ def set_param_for_module(
         param_value = param_value.to_local()
     if param_name not in module_obj._buffers:
         param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
-        param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)
 
     # Remove from missing keys (it's either mismatched, or all good)
     missing_keys.discard(layer_name)
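
For context, the deleted `get_loaded_parameter_class` factory protected loaded weights from the other side: it wrapped each freshly loaded parameter so that in-place initializers became no-ops. A simplified, self-contained sketch of the per-instance patching pattern it relied on (not the removed class itself):

import torch
from types import MethodType

def _skip(self, *args, **kwargs):
    # Swallow the in-place op so a later _init_weights pass cannot overwrite loaded data.
    return self

param = torch.nn.Parameter(torch.ones(2, 2))
# Per-instance override, mirroring how the removed factory patched methods such as normal_.
param.normal_ = MethodType(_skip, param)

param.normal_(mean=0.0, std=0.02)               # silently ignored
assert torch.equal(param.data, torch.ones(2, 2))

With the guarded functions in src/transformers/initialization.py below, the skip happens inside the init call itself, so the wrapper class is no longer needed.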

src/transformers/generation/watermarking.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
 from torch import nn
 from torch.nn import BCELoss
 
+from .. import initialization as init
 from ..modeling_utils import PreTrainedModel
 from ..utils import ModelOutput, logging
 from .configuration_utils import PreTrainedConfig, WatermarkingConfig
@@ -387,7 +388,7 @@ def __init__(self, config):
     def _init_weights(self, module):
         """Initialize the weights."""
         if isinstance(module, nn.Parameter):
-            module.weight.normal_(mean=0.0, std=0.02)
+            init.normal_(module.weight, mean=0.0, std=0.02)
 
     def _compute_posterior(
         self,

src/transformers/initialization.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+from collections import defaultdict
+from contextlib import contextmanager
+
+import torch
+
+
+# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
+# in context managers
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": torch.nn.init.uniform_,
+    "normal_": torch.nn.init.normal_,
+    "constant_": torch.nn.init.constant_,
+    "ones_": torch.nn.init.ones_,
+    "zeros_": torch.nn.init.zeros_,
+    "eye_": torch.nn.init.eye_,
+    "dirac_": torch.nn.init.dirac_,
+    "xavier_uniform_": torch.nn.init.xavier_uniform_,
+    "xavier_normal_": torch.nn.init.xavier_normal_,
+    "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
+    "kaiming_normal_": torch.nn.init.kaiming_normal_,
+    "trunc_normal_": torch.nn.init.trunc_normal_,
+    "orthogonal_": torch.nn.init.orthogonal_,
+    "sparse_": torch.nn.init.sparse_,
+}
+
+
+def uniform_(
+    tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
+    return tensor
+
+
+def normal_(
+    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
+    return tensor
+
+
+def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
+    return tensor
+
+
+def ones_(tensor: torch.Tensor) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["ones_"](tensor)
+    return tensor
+
+
+def zeros_(tensor: torch.Tensor) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
+    return tensor
+
+
+def eye_(tensor: torch.Tensor) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["eye_"](tensor)
+    return tensor
+
+
+def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
+    return tensor
+
+
+def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
+    return tensor
+
+
+def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
+    return tensor
+
+
+def kaiming_uniform_(
+    tensor: torch.Tensor,
+    a: float = 0,
+    mode: str = "fan_in",
+    nonlinearity: str = "leaky_relu",
+    generator: torch.Generator | None = None,
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
+            tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
+        )
+    return tensor
+
+
+def kaiming_normal_(
+    tensor: torch.Tensor,
+    a: float = 0,
+    mode: str = "fan_in",
+    nonlinearity: str = "leaky_relu",
+    generator: torch.Generator | None = None,
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
+            tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
+        )
+    return tensor
+
+
+def trunc_normal_(
+    tensor: torch.Tensor,
+    mean: float = 0.0,
+    std: float = 1.0,
+    a: float = -2.0,
+    b: float = 2.0,
+    generator: torch.Generator | None = None,
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
+    return tensor
+
+
+def orthogonal_(
+    tensor: torch.Tensor,
+    gain: float = 1,
+    generator: torch.Generator | None = None,
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
+    return tensor
+
+
+def sparse_(
+    tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
+) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
+    return tensor
+
+
+def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
+    if not getattr(tensor, "_is_hf_initialized", False):
+        with torch.no_grad():
+            return tensor.copy_(other)
+    return tensor
+
+
+@contextmanager
+def guard_torch_init_functions():
+    """
+    Guard the `torch.nn.init` primitive functions so they behave exactly like the functions in this file, i.e.
+    respect the `_is_hf_initialized` flag and avoid re-initializing a param that was already loaded.
+
+    Usually, all models use the init functions from `transformers`, which are already guarded, but to make extra sure
+    and to cover remote code, we also use this context manager.
+    """
+    originals = defaultdict(dict)
+    try:
+        # Replace all torch funcs by the ones in this file
+        for name in TORCH_INIT_FUNCTIONS.keys():
+            # Here, we need to check all imported modules and hot patch all of them, as torch usually does
+            # something like `from torch.nn.init import xavier_uniform_` in its internals (e.g. in torch.nn.modules,
+            # where MultiheadAttention lives), so the function name is bound at import time and just doing
+            # `setattr(torch.nn.init, name, globals()[name])` is thus not enough
+            for module in sys.modules.values():
+                if module and hasattr(module, name):
+                    originals[module][name] = getattr(module, name)
+                    setattr(module, name, globals()[name])
+        yield
+    finally:
+        # Set back the original functions on all modules
+        for module, functions in originals.items():
+            for name, func in functions.items():
+                setattr(module, name, func)
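
A short usage sketch for `guard_torch_init_functions` above, assuming `transformers.initialization` is importable directly: inside the context, even a raw `torch.nn.init` call (the kind custom or remote modeling code tends to make) is routed through the flag-aware wrappers, and the stock torch functions are restored on exit.

import torch
from torch import nn

from transformers.initialization import guard_torch_init_functions

layer = nn.Linear(8, 8)
layer.weight._is_hf_initialized = True        # e.g. set for a tensor loaded from a checkpoint
loaded = layer.weight.detach().clone()

with guard_torch_init_functions():
    torch.nn.init.normal_(layer.weight, std=0.02)  # no-op: the flag is set
    torch.nn.init.zeros_(layer.bias)               # runs normally: no flag on the bias

assert torch.equal(layer.weight, loaded)
# After the block, torch.nn.init.normal_ and friends are the original torch functions again.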
