@@ -32,42 +32,52 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     weight_map = {
-        "tok_embeddings.weight": "tok_embeddings.weight",
-        "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
-        "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
-        "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
-        "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
-        "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
-        "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
-        "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
-        "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
-        "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
-        "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
-        "norm.weight": "norm.weight",
-        "output.weight": "output.weight",
35+ "model.embed_tokens.weight" : "tok_embeddings.weight" ,
36+ "model.layers.{}.attn.q_proj.weight" : "layers.{}.attention.wq.weight" ,
37+ "model.layers.{}.attn.k_proj.weight" : "layers.{}.attention.wk.weight" ,
38+ "model.layers.{}.attn.v_proj.weight" : "layers.{}.attention.wv.weight" ,
39+ "model.layers.{}.attn.o_proj.weight" : "layers.{}.attention.wo.weight" ,
40+ # "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
41+ # "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
42+ # "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
43+ "model.layers.{}.moe_block.experts.{}.linear.weight" : "layers.{}.block_sparse_moe.cond_ffn.w1.{}" ,
44+ "model.layers.{}.moe_block.experts.{}.linear_1.weight" : "layers.{}.block_sparse_moe.cond_ffn.w2.{}" ,
45+ "model.layers.{}.moe_block.experts.{}.linear_v.weight" : "layers.{}.block_sparse_moe.cond_ffn.w3.{}" ,
46+ "model.layers.{}.moe_block.gate.weight" : "layers.{}.block_sparse_moe.gate.weight" ,
47+ "model.layers.{}.pre_attn_norm.scale" : "layers.{}.pre_attn_norm.weight" ,
48+ "model.layers.{}.post_attn_norm.scale" : "layers.{}.post_attn_norm.weight" ,
49+ "model.layers.{}.pre_moe_norm.scale" : "layers.{}.pre_moe_norm.weight" ,
50+ "model.layers.{}.post_moe_norm.scale" : "layers.{}.post_moe_norm.weight" ,
51+ "model.norm.scale" : "norm.weight" ,
52+ "lm_head.weight" : "output.weight" ,
4853 }
 
-    pt_files = glob.glob(str(checkpoint_dir / "*.pt"))
+    pt_files = glob.glob(str(checkpoint_dir / "*.bin"))
 
     merged_result = {}
     for file in sorted(pt_files):
         state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
     final_result = {}
-    for key, value in merged_result.items():
+    # Snapshot the items so entries can be deleted from merged_result mid-loop.
+    for key, value in list(merged_result.items()):
         if "layers" in key:
-            abstract_key = re.sub(r'.(\d+).', '.{}.', key)
-            layer_num = re.search(r'\d+', key).group(0)
+            # Escape the dots so only ".<digits>." components are abstracted, and
+            # capture every index (layer and, for expert weights, expert number).
+            abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)
+            nums = re.findall(r'\d+', key)
+            if abstract_key not in weight_map:
+                continue
             new_key = weight_map[abstract_key]
             if new_key is None:
                 continue
-            new_key = new_key.format(layer_num)
+            new_key = new_key.format(*nums)
         else:
+            if key not in weight_map:
+                continue
             new_key = weight_map[key]
-
         final_result[new_key] = value
+        del merged_result[key]  # drop the source reference to keep peak memory down
 
     for key in tuple(final_result.keys()):
         if "wq" in key:
             q = final_result[key]
             k = final_result[key.replace("wq", "wk")]
@@ -77,9 +87,21 @@ def convert_hf_checkpoint(
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
         elif "w1" in key or "w3" in key:
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
+            # Expert weights arrive as separate "...wN.<expert>" tensors; when
+            # expert 0 is reached, stack them all into one [num_experts, ...] tensor.
+            if not key.endswith('.0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(config.num_experts)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "w2" in key:
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+            if not key.endswith('.0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(config.num_experts)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
 
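A note on the renaming pass in the first hunk: it turns each concrete checkpoint key into a template by replacing every ".<digits>." component with "{}", then formats the mapped target with all the indices it found. A minimal sketch of that technique, using an illustrative key and a one-entry map mirroring the style of weight_map (not the real checkpoint contents):

import re

weight_map = {
    "model.layers.{}.moe_block.experts.{}.linear.weight":
        "layers.{}.block_sparse_moe.cond_ffn.w1.{}",
}

key = "model.layers.3.moe_block.experts.5.linear.weight"
abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)  # ".3." and ".5." become ".{}."
nums = re.findall(r'\d+', key)                    # ['3', '5']
new_key = weight_map[abstract_key].format(*nums)
print(new_key)  # layers.3.block_sparse_moe.cond_ffn.w1.5

re.findall may pick up more digits than the target template has slots (e.g. the "1" in "linear_1"); that is harmless because str.format silently ignores surplus positional arguments.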
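The w1/w2/w3 branches in the second hunk rely on the renaming pass emitting one "...wN.<expert>" entry per expert; when the loop meets expert 0 it gathers all slices into a single tensor with a leading expert dimension. (The old code instead reshaped one fused tensor, and transposed w2 via permute(0, 2, 1); the stacked w2 now stays in the source checkpoint's orientation, so the consuming cond_ffn module must expect that layout.) A self-contained sketch of the stacking step, with toy shapes and random data standing in for real weights:

import torch

num_experts, out_dim, in_dim = 8, 16, 4  # toy sizes; real dims come from the config
final_result = {
    f"layers.0.block_sparse_moe.cond_ffn.w1.{i}": torch.randn(out_dim, in_dim)
    for i in range(num_experts)
}

for key in tuple(final_result.keys()):
    if ".w1." not in key or not key.endswith(".0"):
        continue  # only expert 0 triggers the gather; the other experts are skipped
    full_keys = [key[:-1] + str(i) for i in range(num_experts)]
    # Fuse the per-expert [out, in] slices into one [num_experts, out, in] tensor.
    final_result[key[:-2]] = torch.stack([final_result[k] for k in full_keys], dim=0)
    for k in full_keys:
        del final_result[k]

print(final_result["layers.0.block_sparse_moe.cond_ffn.w1"].shape)
# torch.Size([8, 16, 4])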