 # LICENSE file in the root directory of this source tree.
 import json
 import re
+import shutil
 import sys
 from pathlib import Path
 from typing import Optional
@@ -27,33 +28,62 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = checkpoint_dir.name

+    # Llama 3 8B doesn't need key remapping; instead, the original/consolidated.NN.pth
+    # file is loaded and re-saved as model.pth as-is.
+    # Llama 3 70B can't easily be merged into one model.pth file, though, since the weight
+    # names in the state dict are identical across the consolidated.NN.pth shards, so it is
+    # not currently supported.
+    # Along with this, we need to copy original/tokenizer.model (a tiktoken model) into the
+    # checkpoint directory next to model.pth.
+    is_llama3 = "Llama-3" in model_name
+    if is_llama3:
+        # Check whether there are multiple original/consolidated.NN.pth shards and report
+        # an error if so, since merging them is not supported for Llama 3.
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
+        if len(bin_files) > 1:
+            raise ValueError(
+                f"Multiple consolidated.NN.pth files found in {original_dir}. "
+                "Merging them into one model.pth file is not supported for Llama 3.")
+
+
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")

     # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-    assert model_map_json.is_file()
-
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
-
-    weight_map = {
-        "model.embed_tokens.weight": "tok_embeddings.weight",
-        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-        "model.norm.weight": "norm.weight",
-        "lm_head.weight": "output.weight",
-    }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    if not is_llama3:
+        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
+
+        assert model_map_json.is_file()
+
+        with open(model_map_json) as json_map:
+            bin_index = json.load(json_map)
+
+        weight_map = {
+            "model.embed_tokens.weight": "tok_embeddings.weight",
+            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            "model.norm.weight": "norm.weight",
+            "lm_head.weight": "output.weight",
+        }
+        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    else:
+        # There is no separate pytorch_model.bin.index.json file for Llama 3.
+        # Instead, we just use all original/consolidated.NN.pth files directly,
+        # without any weight-name remapping.
+        weight_map = None
+        original_dir = checkpoint_dir / "original"
+        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
+        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
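+        # Setting weight_map to None skips the remapping/permutation pass further below
+        # for these shards.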
+

     def permute(w, n_head):
         dim = config.dim
@@ -68,32 +98,41 @@ def permute(w, n_head):
         state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
     final_result = {}
-    for key, value in merged_result.items():
-        if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
-            new_key = new_key.format(layer_num)
-        else:
-            new_key = weight_map[key]
-
-        final_result[new_key] = value
-
-    for key in tuple(final_result.keys()):
-        if "wq" in key:
-            q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
-            q = permute(q, config.n_head)
-            k = permute(k, config.n_local_heads)
-            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-            del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+    if weight_map is not None:
+        for key, value in merged_result.items():
+            if "layers" in key:
+                abstract_key = re.sub(r'(\d+)', '{}', key)
+                layer_num = re.search(r'\d+', key).group(0)
+                new_key = weight_map[abstract_key]
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = weight_map[key]
+
+            final_result[new_key] = value
+
+        for key in tuple(final_result.keys()):
+            if "wq" in key:
+                q = final_result[key]
+                k = final_result[key.replace("wq", "wk")]
+                v = final_result[key.replace("wq", "wv")]
+                q = permute(q, config.n_head)
+                k = permute(k, config.n_local_heads)
+                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+                del final_result[key]
+                del final_result[key.replace("wq", "wk")]
+                del final_result[key.replace("wq", "wv")]
+    else:
+        final_result = merged_result
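+        # For Llama 3, the keys from consolidated.NN.pth are kept exactly as loaded;
+        # no renaming and no q/k head permutation is applied.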
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
+    if is_llama3:
+        original_dir = checkpoint_dir / "original"
+        tokenizer_model = original_dir / "tokenizer.model"
+        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
+        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
+        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
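+        # After this copy, the (tiktoken-based) tokenizer.model sits next to model.pth
+        # in checkpoint_dir.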

 if __name__ == '__main__':
     import argparse