Commit 5ecc871

shuhuayu and Achazwl authored
Fix bugs in initial_load_in_hf when enable_weight_tying=true in Qwen3 (#1999)
Rebased on main to merge PR #1964.

Co-authored-by: William <acha131441373@gmail.com>
Co-authored-by: Achazwl <323163497@qq.com>
1 parent: b6b2c2d

File tree

1 file changed (+9, −0)

torchtitan/models/qwen3/model/state_dict_adapter.py

Lines changed: 9 additions & 0 deletions
@@ -104,6 +104,8 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
 else:
     if key not in to_hf_map:
         continue
+    if self.model_args.enable_weight_tying and key == "output.weight":
+        continue
     new_key = to_hf_map[key]
     hf_state_dict[new_key] = value

@@ -118,6 +120,13 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
 state_dict = {}
 expert_weights_by_layer = {}  # {layer: {abstract_key: {expert_id: tensor}}}

+if (
+    self.model_args.enable_weight_tying
+    and "lm_head.weight" not in hf_state_dict
+):
+    assert "model.embed_tokens.weight" in hf_state_dict
+    hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]
+
 for key, value in hf_state_dict.items():
     if "mlp.experts" in key:
         abstract_key = re.sub(r"(\d+)", "{}", key, count=2)

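For context, here is a minimal sketch of why both hunks are needed. It is not from the repo: TinyLM and its attribute names are hypothetical stand-ins for the Qwen3 model. With enable_weight_tying, the output projection shares one tensor with the token embedding, so the weight appears under two keys in the torchtitan state dict but only once (as model.embed_tokens.weight) in a tied HF checkpoint. Export must skip the redundant key; import must reconstruct it.

import torch.nn as nn

class TinyLM(nn.Module):
    """Toy tied-weight LM; stands in for the Qwen3 model here."""

    def __init__(self, vocab_size: int = 8, dim: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Weight tying: both modules point at the same Parameter, so a
        # checkpoint only needs to persist one of the two keys.
        self.output.weight = self.tok_embeddings.weight

sd = TinyLM().state_dict()
# Both keys exist in the state dict and share the same storage.
assert sd["output.weight"].data_ptr() == sd["tok_embeddings.weight"].data_ptr()

# to_hf direction (mirrors the first hunk): drop the redundant tied key on
# export so it is not written twice. The real adapter also renames keys via
# to_hf_map, which this sketch omits.
hf_sd = {k: v for k, v in sd.items() if k != "output.weight"}

# from_hf direction (mirrors the second hunk): a tied HF checkpoint ships
# without lm_head.weight, so rebuild it from the embedding before the usual
# key remapping runs.
if "lm_head.weight" not in hf_sd:
    hf_sd["lm_head.weight"] = hf_sd["tok_embeddings.weight"]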