 import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 # support running without installing as a package
@@ -28,62 +28,49 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = checkpoint_dir.name
 
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since the names of the
-    # weights in the state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along with this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report an error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
@@ -95,39 +82,40 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
+    if 'llama-3' in model_name.lower():
         original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
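
For context, below is a minimal, self-contained sketch of the tensor-name remapping the loop above performs. The index-file contents and shard filenames are invented for illustration, and the weight_map here is only an excerpt; the real script additionally loads each shard (via torch.load or safetensors) and fuses the permuted Q/K/V projections into a single wqkv tensor.

import re

# Hypothetical excerpt of a Hugging Face index file: "weight_map" maps each
# tensor name to the shard file that stores it (shard names are invented here).
bin_index = {
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "model.layers.1.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
    }
}

# Excerpt of the script's weight_map: HF tensor names -> converted checkpoint names.
weight_map = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
}

for key in bin_index["weight_map"]:
    if "layers" in key:
        # Swap the layer index for "{}" to look up the template, then restore it.
        abstract_key = re.sub(r"(\d+)", "{}", key)
        layer_num = re.search(r"\d+", key).group(0)
        new_key = weight_map[abstract_key].format(layer_num)
    else:
        new_key = weight_map[key]
    print(f"{key} -> {new_key}")

# Expected output:
# model.embed_tokens.weight -> tok_embeddings.weight
# model.layers.0.self_attn.q_proj.weight -> layers.0.attention.wq.weight
# model.layers.1.mlp.gate_proj.weight -> layers.1.feed_forward.w1.weight

The real loop also skips keys whose mapping is None (the rotary_emb.inv_freq buffers) before the Q/K/V fusion step.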