Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
387 commits
Select commit Hold shift + click to select a range
4d79709
ah actually we don't discard lm head if missing -> needs to be moved …
ArthurZucker Nov 3, 2025
d1e84db
fix some tests
ArthurZucker Nov 3, 2025
f2938df
small fixes
ArthurZucker Nov 3, 2025
22fcdaf
up
ArthurZucker Nov 3, 2025
7d78aa1
up
ArthurZucker Nov 3, 2025
80517f5
dik why we tie weights twice but,..,,.
ArthurZucker Nov 3, 2025
2ff8532
ups
ArthurZucker Nov 3, 2025
d923061
removeunused
ArthurZucker Nov 3, 2025
ce8c1c1
fix hunyuan
ArthurZucker Nov 3, 2025
23e3ed7
small fix
ArthurZucker Nov 3, 2025
a8fb554
nits
ArthurZucker Nov 3, 2025
ab6ee8a
ish
ArthurZucker Nov 3, 2025
77ccbb1
up
ArthurZucker Nov 3, 2025
8a8beff
rev
ArthurZucker Nov 3, 2025
02386ce
fix more tie weights keys
ArthurZucker Nov 3, 2025
1c87945
small fixes
ArthurZucker Nov 3, 2025
00b95ee
nit
ArthurZucker Nov 3, 2025
a170f29
update
ArthurZucker Nov 3, 2025
8b924a3
fix and fix
ArthurZucker Nov 3, 2025
8f7b1d0
fix a test
ArthurZucker Nov 3, 2025
9386217
glubs
ArthurZucker Nov 3, 2025
4894a25
current shitty changes
ArthurZucker Nov 3, 2025
da7dc10
ship validated ones
ArthurZucker Nov 4, 2025
d7c8171
more
ArthurZucker Nov 4, 2025
e088408
more update
ArthurZucker Nov 4, 2025
4f212de
more
ArthurZucker Nov 4, 2025
dc5a22c
more
ArthurZucker Nov 4, 2025
675b2bc
more
ArthurZucker Nov 4, 2025
f85f239
mllama
ArthurZucker Nov 4, 2025
76b6a92
more up
ArthurZucker Nov 4, 2025
ba1a8b6
fix ernie
ArthurZucker Nov 4, 2025
ba3de5a
fix xopies
ArthurZucker Nov 4, 2025
8fd255c
up more
ArthurZucker Nov 4, 2025
5d7507b
more fixes
ArthurZucker Nov 4, 2025
0fb2340
up
ArthurZucker Nov 4, 2025
32b9273
up
ArthurZucker Nov 4, 2025
0b95826
fix-copies
ArthurZucker Nov 4, 2025
5794d27
fix more
ArthurZucker Nov 4, 2025
5e71bd4
more updates
ArthurZucker Nov 4, 2025
20d1b34
AI UPDATE
ArthurZucker Nov 4, 2025
89846e7
up
ArthurZucker Nov 5, 2025
a581fd7
hoey
ArthurZucker Nov 5, 2025
1652c9c
make it fast
Cyrilvallez Nov 5, 2025
dcad703
fix
Cyrilvallez Nov 5, 2025
c921ced
lol
ArthurZucker Nov 5, 2025
50714d8
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
8936cc4
fix asjusting
ArthurZucker Nov 5, 2025
5c54332
more fixes
ArthurZucker Nov 5, 2025
ff10878
_dtype nit
ArthurZucker Nov 5, 2025
9601b82
up
ArthurZucker Nov 5, 2025
db02b9d
nit
ArthurZucker Nov 5, 2025
42fd4c4
update
ArthurZucker Nov 5, 2025
4527171
update
ArthurZucker Nov 5, 2025
bd36211
remove semaphores
Cyrilvallez Nov 5, 2025
e2aefee
fix import to avoid jit execution
Cyrilvallez Nov 5, 2025
74a0e9c
try to remove custom tiing logic when its stupid
ArthurZucker Nov 5, 2025
ead2ac3
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
e7165da
fix more individual models
ArthurZucker Nov 5, 2025
2ff765e
fix whisper as well
ArthurZucker Nov 5, 2025
912562c
fix?
ArthurZucker Nov 5, 2025
c43495a
fox umt5
ArthurZucker Nov 5, 2025
57988f2
improve tqdm bar
Cyrilvallez Nov 5, 2025
8c16de1
cleanup a bit
Cyrilvallez Nov 5, 2025
b8927d6
oupsi
Cyrilvallez Nov 5, 2025
2733ff6
some updates
ArthurZucker Nov 5, 2025
8baa3fe
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
d91701f
improve
Cyrilvallez Nov 5, 2025
5146dec
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
Cyrilvallez Nov 5, 2025
acc5b24
remove all buffering -> much faster without it
Cyrilvallez Nov 5, 2025
58389a1
remove some tie_weights custome funcs when not needed
ArthurZucker Nov 5, 2025
92c0229
more fixes related to strict matching regex
ArthurZucker Nov 5, 2025
d9e7fe6
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
b57d789
remove ALL custom tie weights
ArthurZucker Nov 5, 2025
ef8b6c3
small update
ArthurZucker Nov 5, 2025
a228fd0
revert change to init scheme (no need for params)
Cyrilvallez Nov 5, 2025
07574dd
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
2526cc5
mixtral init
Cyrilvallez Nov 5, 2025
6cb3794
try less strict source check
ArthurZucker Nov 5, 2025
e4cadfb
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
3fea865
tied weight first shot to the fiiiixxxxxx
Cyrilvallez Nov 5, 2025
82f94b8
does this help?
ArthurZucker Nov 5, 2025
84dd6eb
:)
ArthurZucker Nov 5, 2025
cc08195
fix some ppolry defined tied_weights_keys for now
ArthurZucker Nov 5, 2025
f692f4b
subclass nn.Parameters
ArthurZucker Nov 7, 2025
2fa058f
up
ArthurZucker Nov 7, 2025
78d4622
lol
ArthurZucker Nov 7, 2025
8ff4ad5
Ouiiii
ArthurZucker Nov 7, 2025
3222678
fix led
ArthurZucker Nov 7, 2025
9a76a6e
fix long cat flash
ArthurZucker Nov 7, 2025
9fde9f7
fix qwen and long cat flash
ArthurZucker Nov 7, 2025
074a449
properly fix qwen init
ArthurZucker Nov 7, 2025
dde5500
just push this for now
ArthurZucker Nov 7, 2025
0e7d2d0
propnet is dumb
ArthurZucker Nov 7, 2025
18b02ee
update
ArthurZucker Nov 7, 2025
9c0db72
push
ArthurZucker Nov 7, 2025
75d3afc
remove explict sharing of some tied keys.
ArthurZucker Nov 7, 2025
85ab085
update decoder.bias
ArthurZucker Nov 7, 2025
443573a
moe case
ArthurZucker Nov 7, 2025
f8f0973
more changes to untangle old hardcoded ting
ArthurZucker Nov 7, 2025
5c9d56c
fixup
ArthurZucker Nov 7, 2025
a0029f2
Merge branch 'main' into refactor-weight-loading
ArthurZucker Nov 7, 2025
44943fb
fix big faileurs
ArthurZucker Nov 7, 2025
76d66be
fix prophnet
ArthurZucker Nov 7, 2025
d176b48
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 7, 2025
3ffc59e
fix resize token embeddings
ArthurZucker Nov 10, 2025
2a00e49
nits
ArthurZucker Nov 10, 2025
f7d0183
fix xcodex
ArthurZucker Nov 10, 2025
bbf5b00
asyncio?
ArthurZucker Nov 10, 2025
0412832
fix smart apply
ArthurZucker Nov 10, 2025
c137ea3
fix data-2-vec
ArthurZucker Nov 10, 2025
7b7c990
[build-ci-image]
ArthurZucker Nov 10, 2025
de74aeb
checkout
ArthurZucker Nov 10, 2025
94a53d4
uupdate
ArthurZucker Nov 10, 2025
8755a4b
fix hunyuan
ArthurZucker Nov 10, 2025
5be67b9
update error message
ArthurZucker Nov 10, 2025
86a4e51
fix deformable detr
ArthurZucker Nov 10, 2025
09bcd2e
fixes
ArthurZucker Nov 10, 2025
7b457fd
fix init weights for non param gate up projs
ArthurZucker Nov 10, 2025
e033947
shared todo?
ArthurZucker Nov 10, 2025
f93f357
update some models
ArthurZucker Nov 10, 2025
2f0a6ae
big revert, don't break this behaviour
ArthurZucker Nov 10, 2025
3c8c757
ty @SunMarc this fixes the buffers
ArthurZucker Nov 10, 2025
f5a7c33
mt5 fuck
ArthurZucker Nov 10, 2025
647f720
fix lxmbert
ArthurZucker Nov 10, 2025
bed6ea1
nuke slow test fetcher
ArthurZucker Nov 10, 2025
2ec0a5f
fix zamba and deepcopy for now
ArthurZucker Nov 10, 2025
f9c7ef8
fix zamba tied weight keys! ~
ArthurZucker Nov 10, 2025
8df3ffd
fix-copies
ArthurZucker Nov 10, 2025
e76481b
update fetch terst
ArthurZucker Nov 10, 2025
de00751
fix gradient for test modeling common!
ArthurZucker Nov 10, 2025
cdd1a9b
break "shared" for now I will fix tomorrow changes are properly isoal…
ArthurZucker Nov 10, 2025
d3f6476
does this fix marian? probably not
ArthurZucker Nov 10, 2025
0a7db83
fix some vlms
ArthurZucker Nov 10, 2025
1814200
D fine seems to handle this well
ArthurZucker Nov 10, 2025
b77825d
glob is fine actually
ArthurZucker Nov 11, 2025
5dbb783
fix dab detr
ArthurZucker Nov 11, 2025
9edc81b
small steps
ArthurZucker Nov 11, 2025
970f4e5
opusy
ArthurZucker Nov 11, 2025
0361d47
fix some more models?
ArthurZucker Nov 11, 2025
dc75773
yups
ArthurZucker Nov 11, 2025
cdb1284
better erro
ArthurZucker Nov 11, 2025
de9a2d9
fix?
ArthurZucker Nov 11, 2025
b9a9f4d
fix double escape
ArthurZucker Nov 11, 2025
c944619
escape wehere it makes sense
ArthurZucker Nov 11, 2025
f910524
??
ArthurZucker Nov 11, 2025
4aa2ade
fix ibert
ArthurZucker Nov 11, 2025
2ef1c2b
fix tvp as well
ArthurZucker Nov 11, 2025
b98a7bc
more fxes
ArthurZucker Nov 11, 2025
74e6c87
try always download ref PR
ArthurZucker Nov 11, 2025
5064edd
ONONONO
ArthurZucker Nov 11, 2025
3f8a304
big fixup
ArthurZucker Nov 11, 2025
3ecaa63
more fixup
ArthurZucker Nov 11, 2025
f384524
small step
ArthurZucker Nov 11, 2025
290337a
small nits
ArthurZucker Nov 11, 2025
76b388c
nits
ArthurZucker Nov 11, 2025
e69b988
brut force some stuff
ArthurZucker Nov 11, 2025
c2781f5
fix vilt
ArthurZucker Nov 11, 2025
f64ee96
make sure special models that always need tie always tie
ArthurZucker Nov 11, 2025
a3e4015
cleaning up
ArthurZucker Nov 11, 2025
9eecbd2
small nits
ArthurZucker Nov 11, 2025
b2fa432
fix zamba and bridge tower!
ArthurZucker Nov 11, 2025
dbbfdf2
just fixup
ArthurZucker Nov 11, 2025
ab4890c
potential culprits
ArthurZucker Nov 11, 2025
937ebf3
revert bark and fix bridgetower
ArthurZucker Nov 11, 2025
e4f9697
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Nov 11, 2025
17803ce
remove now non existant tie_weights
ArthurZucker Nov 11, 2025
9f6838a
?
ArthurZucker Nov 11, 2025
1afb3eb
lol reformer actually had nothing tied!
ArthurZucker Nov 11, 2025
f01a149
wow these two fucking models were really not well made
ArthurZucker Nov 11, 2025
0b36980
fix sam family!
ArthurZucker Nov 11, 2025
d740c82
fix bark revision
ArthurZucker Nov 11, 2025
6f3940e
fix speech2test ?
ArthurZucker Nov 11, 2025
b2f6f61
push this for now....
ArthurZucker Nov 11, 2025
ade8dab
upsy
ArthurZucker Nov 11, 2025
f956ccf
the fuck
ArthurZucker Nov 11, 2025
99c6fd4
fix rtdetr
ArthurZucker Nov 11, 2025
1ffcfc3
update
ArthurZucker Nov 11, 2025
ee62aec
proper
ArthurZucker Nov 11, 2025
6ec80f8
wow that one 's annoying
ArthurZucker Nov 11, 2025
b05e329
update
ArthurZucker Nov 11, 2025
2606596
try to find the culprit
ArthurZucker Nov 11, 2025
d9e8a09
get some help on common
ArthurZucker Nov 12, 2025
581665a
nit about general init and cls.padding_idx
ArthurZucker Nov 12, 2025
c43bc68
revert num workers update
ArthurZucker Nov 12, 2025
b6fe415
remove old loading func
Cyrilvallez Nov 12, 2025
4bb8e5c
fix glob
ArthurZucker Nov 12, 2025
7d52b06
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
455bcc7
add annotations
Cyrilvallez Nov 12, 2025
fc884c0
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
Cyrilvallez Nov 12, 2025
2e0ed5d
fix re
ArthurZucker Nov 12, 2025
3ddd1cc
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
1f86a10
small improvements
Cyrilvallez Nov 12, 2025
4d56fbf
fix conflict
Cyrilvallez Nov 12, 2025
67a8eeb
clean some stuff
Cyrilvallez Nov 12, 2025
e9168ff
improvements
Cyrilvallez Nov 12, 2025
feda22d
someone did not understannnnnnd what I tried to dooo or does BNB not …
ArthurZucker Nov 12, 2025
70841c9
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
52248ba
gluos
ArthurZucker Nov 12, 2025
e8dd4a4
fix case when `.` is just not there
ArthurZucker Nov 12, 2025
1c67fc4
remove unused arg
Cyrilvallez Nov 12, 2025
e20ed00
recover orignal parameter/buffer using _original
SunMarc Nov 12, 2025
827c42a
fix glob issu
ArthurZucker Nov 12, 2025
e5e4d28
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
4db2aa6
this?
ArthurZucker Nov 12, 2025
2b16c17
deepspeed best-effort
Cyrilvallez Nov 12, 2025
c411ddb
remove unused stuff
Cyrilvallez Nov 12, 2025
56d368b
Update tie weight keys as they were just wroong
ArthurZucker Nov 12, 2025
85d0ac1
up
ArthurZucker Nov 12, 2025
daa642c
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 12, 2025
bbf71b9
augustuc clauss, a gloubs gloups gloubs
ArthurZucker Nov 12, 2025
127e4d5
fixup
ArthurZucker Nov 12, 2025
7954185
fixup
ArthurZucker Nov 12, 2025
f7cd4b3
there was fucking typo
ArthurZucker Nov 12, 2025
f9e747e
mrain
ArthurZucker Nov 12, 2025
57bf5b2
nits
ArthurZucker Nov 12, 2025
c38ad24
fix marian 3 remaining tests
ArthurZucker Nov 12, 2025
d7be7df
one more
ArthurZucker Nov 12, 2025
729e3df
fix some of the copies, not all :)
ArthurZucker Nov 12, 2025
c95a3f1
small cleanup
ArthurZucker Nov 12, 2025
8778840
one propertest
ArthurZucker Nov 13, 2025
1181e3f
fix core model loadig tes
ArthurZucker Nov 13, 2025
b750e6b
attempt a new test
ArthurZucker Nov 13, 2025
3178c3f
fix some of the annoying tests by supporting reading .bin sometimes
ArthurZucker Nov 13, 2025
d6ab250
push
ArthurZucker Nov 13, 2025
0695197
push more small fixes
ArthurZucker Nov 13, 2025
fd5a75a
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Nov 13, 2025
f54b528
remove 1 useless test
ArthurZucker Nov 13, 2025
1abf6a9
up
ArthurZucker Nov 13, 2025
3014290
fix audio flamingo post rebase
ArthurZucker Nov 13, 2025
1f1bea3
fixup
ArthurZucker Nov 13, 2025
c2dbca0
some small updatess
ArthurZucker Nov 13, 2025
347b966
fix sam models
ArthurZucker Nov 13, 2025
40ed636
nits
ArthurZucker Nov 13, 2025
3b2f934
up
ArthurZucker Nov 13, 2025
fb0fb89
updates
ArthurZucker Nov 13, 2025
92e2771
onem ore
ArthurZucker Nov 13, 2025
06f2ba9
skip this stupid test
ArthurZucker Nov 13, 2025
3d5c86c
some other fixes
ArthurZucker Nov 13, 2025
15bc48e
fixup
ArthurZucker Nov 13, 2025
47743f8
update
ArthurZucker Nov 13, 2025
d77cf57
skip more offloaded stuff
ArthurZucker Nov 13, 2025
75f2bd4
oups
ArthurZucker Nov 13, 2025
08ad69b
ups
ArthurZucker Nov 13, 2025
b605e1a
update mixtral
ArthurZucker Nov 13, 2025
91d40b8
skip this one
ArthurZucker Nov 13, 2025
638bbfc
LET"SGO
ArthurZucker Nov 13, 2025
7daacb4
fixup
ArthurZucker Nov 13, 2025
22c19a7
rope delta order
ArthurZucker Nov 13, 2025
6d89354
fix csm
ArthurZucker Nov 13, 2025
9ccb693
small nit
ArthurZucker Nov 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 4 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ jobs:
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
- run: python utils/tests_fetcher.py --filter_tests
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt || true
- run: python utils/tests_fetcher.py --filter_tests || true
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
if [ ! -s test_preparation/generated_config.yml ]; then
Expand Down Expand Up @@ -98,8 +98,8 @@ jobs:
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt
- run: python utils/tests_fetcher.py --filter_tests
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt || true
- run: python utils/tests_fetcher.py --filter_tests || true
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
if [ ! -s test_preparation/generated_config.yml ]; then
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py
Expand Down
14 changes: 7 additions & 7 deletions docs/source/de/add_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```

Sie können weitere benutzerdefinierte Schemata verwenden, wenn Sie eine spezielle Initialisierung für einige Module benötigen. Zum Beispiel in
Expand All @@ -533,9 +533,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```

Das Flag `_is_hf_initialized` wird intern verwendet, um sicherzustellen, dass wir ein Submodul nur einmal initialisieren. Wenn Sie es auf
Expand Down
14 changes: 7 additions & 7 deletions docs/source/en/add_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```

The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
Expand All @@ -339,9 +339,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```

### Convert checkpoints to Transformers
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_multi.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we do zeros? Might make sense to have 1s instead?

Ig this is tied to not using init weights

```

Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/it/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ Per quanto riguarda la classe `TrainingArguments`:
- L'argomento `evaluate_during_training` di `TrainingArguments` è deprecato a favore di `eval_strategy`.

Per quanto riguarda il modello Transfo-XL:
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_words_embeddings`.
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_word_embeddings`.
- Il metodo di modellazione `reset_length` di Transfo-XL diventa `reset_memory_length`.

Per quanto riguarda le pipeline:
Expand Down
14 changes: 7 additions & 7 deletions docs/source/ja/add_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```

特定のモジュールに特別な初期化が必要な場合、カスタムスキームをさらに持つことができます。たとえば、
Expand All @@ -431,9 +431,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```

`_is_hf_initialized`フラグは、サブモジュールを一度だけ初期化することを確実にするために内部で使用されます。
Expand Down
14 changes: 7 additions & 7 deletions docs/source/ko/add_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```

몇 가지 모듈에 대해 특별한 초기화가 필요한 경우 사용자 정의 방식을 사용할 수도 있습니다. 예를 들어, `Wav2Vec2ForPreTraining`에서 마지막 두 개의 선형 레이어는 일반적인 PyTorch `nn.Linear`의 초기화를 가져야 하지만, 다른 모든 레이어는 위와 같은 초기화를 사용해야 합니다. 이는 다음과 같이 코드화됩니다:
Expand All @@ -371,9 +371,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```

`_is_hf_initialized` 플래그는 서브모듈을 한 번만 초기화하도록 내부적으로 사용됩니다. `module.project_q``module.project_hid`에 대해 `True`로 설정함으로써, 우리가 수행한 사용자 정의 초기화가 이후에 덮어쓰이지 않도록 합니다. 즉, `_init_weights` 함수가 이들에게 적용되지 않습니다.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/ko/perf_infer_gpu_multi.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```

배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다.
Expand Down
20 changes: 7 additions & 13 deletions examples/modular-transformers/modeling_dummy_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,10 @@ def __init__(self, config):

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
Expand All @@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, DummyBertLMPredictionHead):
module.bias.data.zero_()
module.bias.zero_()


@auto_docstring(
Expand Down
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_my_new_model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _init_weights(self, module):

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()
module.weight.zero_()


class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
Expand Down
16 changes: 12 additions & 4 deletions examples/modular-transformers/modeling_new_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()


def token_type_ids_mask_function(
Expand Down Expand Up @@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related

def __init__(self, config):
Expand All @@ -440,7 +440,15 @@ def __init__(self, config):
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
prefix = "model.language_model."
prefixed_mapping = {
f"{prefix}{target}": f"{prefix}{source}"
for target, source in self.language_model._tied_weights_keys.items()
}
if isinstance(self._tied_weights_keys, dict):
self._tied_weights_keys.update(prefixed_mapping)
else:
self._tied_weights_keys = prefixed_mapping
self.post_init()

def get_input_embeddings(self):
Expand Down
20 changes: 7 additions & 13 deletions examples/modular-transformers/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,10 @@ def __init__(self, config):

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
Expand All @@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, RobertaLMPredictionHead):
module.bias.data.zero_()
module.bias.zero_()


@auto_docstring(
Expand Down
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_test_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,11 +846,11 @@ def _init_weights(self, module):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if hasattr(module, "reference_points") and not self.config.two_stage:
Expand Down
10 changes: 9 additions & 1 deletion examples/modular-transformers/modular_new_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ def __init__(self, config):
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
prefix = "model.language_model."
prefixed_mapping = {
f"{prefix}{target}": f"{prefix}{source}"
for target, source in self.language_model._tied_weights_keys.items()
}
if isinstance(self._tied_weights_keys, dict):
self._tied_weights_keys.update(prefixed_mapping)
else:
self._tied_weights_keys = prefixed_mapping

self.post_init()

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def to_diff_dict(self) -> dict[str, Any]:
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
else self.quantization_config
)
self.dict_dtype_to_str(serializable_config_dict)
Expand Down Expand Up @@ -910,7 +910,7 @@ def to_dict(self) -> dict[str, Any]:
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
else self.quantization_config
)
self.dict_dtype_to_str(output)
Expand Down
Loading