 import os
 import time

-from collections.abc import Sequence
 from absl import app
 from absl import flags
 from etils import epath
     "output.weight": "ColumnParallelLinear",
 }

-_QUANTIZED_WEIGHTS_TO_SCALER_NAME = {
+_LLAMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME = {
     "tok_embeddings.weight": "tok_embeddings.weight_scaler",
     "attention.wq.weight": "attention.wq.weight_scaler",
     "attention.wk.weight": "attention.wk.weight_scaler",
     "output.weight": "output.weight_scaler",
 }

-
-def _quantize_state_dict(state_dict):
-  updated_weights = {}
-  for key, val in state_dict.items():
-    for qname, qscale_name in _QUANTIZED_WEIGHTS_TO_SCALER_NAME.items():
-      if key.endswith(qname):
-        new_weights, scaler = quantize.quantize_torch_int8(
-            val, reduce_axis=(1,)
-        )
-        updated_weights[key] = new_weights
-        scale_name = key[: -len(qname)] + qscale_name
-        updated_weights[scale_name] = scaler
-  state_dict.update(updated_weights)
-  return state_dict
-
-
-_QUANTIZE_LINEAR_WEIGHTS = {
-    "attention.wq.weight",
-    "attention.wk.weight",
-    "attention.wv.weight",
-    "attention.wo.weight",
-    "feed_forward.w1.weight",
-    "feed_forward.w2.weight",
-    "feed_forward.w3.weight",
-    "output.weight",
+_GEMMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME = {
+    "self_attn.o_proj.weight": "self_attn.o_proj.weight_scaler",
+    "self_attn.wq.weight": "self_attn.wq.weight_scaler",
+    "self_attn.wk.weight": "self_attn.wk.weight_scaler",
+    "self_attn.wv.weight": "self_attn.wv.weight_scaler",
+    "mlp.gate_proj.weight": "mlp.gate_proj.weight_scaler",
+    "mlp.up_proj.weight": "mlp.up_proj.weight_scaler",
+    "mlp.down_proj.weight": "mlp.down_proj.weight_scaler",
+    "embedder.weight": "embedder.weight_scaler",
 }


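Note that both maps are keyed by weight-name suffixes, not full parameter names; `_quantize_state_dict` below matches them with `endswith` and grafts the scaler suffix onto the per-layer prefix. A quick illustration of that key derivation (the layer-0 key here is hypothetical):

```python
# Hypothetical fully qualified key; the suffixes come from the maps above.
key = "model.layers.0.self_attn.wq.weight"
qname = "self_attn.wq.weight"
qscale_name = "self_attn.wq.weight_scaler"
scale_name = key[: -len(qname)] + qscale_name  # keeps the per-layer prefix
assert scale_name == "model.layers.0.self_attn.wq.weight_scaler"
```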
-def _quantize_state_dict(state_dict):
+def _quantize_state_dict(state_dict, weight_map, weight_axis):
   updated_weights = {}
   for key, val in state_dict.items():
-    for qname in _QUANTIZE_LINEAR_WEIGHTS:
+    for qname, qscale_name in weight_map.items():
       if key.endswith(qname):
         new_weights, scaler = quantize.quantize_torch_int8(
-            val, reduce_axis=(1,)
+            val, reduce_axis=(weight_axis(key),)
         )
         updated_weights[key] = new_weights
-        scale_name = key + "_scaler"
+        scale_name = key[: -len(qname)] + qscale_name
         updated_weights[scale_name] = scaler.squeeze()
-  tok_weights, tok_scalers = quantize.quantize_torch_int8(
-      state_dict["tok_embeddings.weight"], reduce_axis=(0,)
-  )
-  updated_weights["tok_embeddings.weight"] = tok_weights
-  updated_weights["tok_embeddings.weight_scaler"] = tok_scalers.squeeze()
-  state_dict.update(updated_weights)
-  return state_dict
-
-
-_QUANTIZE_LINEAR_WEIGHTS = {
-    "attention.wq.weight",
-    "attention.wk.weight",
-    "attention.wv.weight",
-    "attention.wo.weight",
-    "feed_forward.w1.weight",
-    "feed_forward.w2.weight",
-    "feed_forward.w3.weight",
-    "output.weight",
-}
-
-
-def _quantize_state_dict(state_dict):
-  updated_weights = {}
-  for key, val in state_dict.items():
-    for qname in _QUANTIZE_LINEAR_WEIGHTS:
-      if key.endswith(qname):
-        new_weights, scaler = quantize.quantize_torch_int8(
-            val, reduce_axis=(1,)
-        )
-        updated_weights[key] = new_weights
-        scale_name = key + "_scaler"
-        updated_weights[scale_name] = scaler.squeeze()
-  tok_weights, tok_scalers = quantize.quantize_torch_int8(
-      state_dict["tok_embeddings.weight"], reduce_axis=(0,)
-  )
-  updated_weights["tok_embeddings.weight"] = tok_weights
-  updated_weights["tok_embeddings.weight_scaler"] = tok_scalers.squeeze()
   state_dict.update(updated_weights)
   return state_dict

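`quantize.quantize_torch_int8` is defined outside this diff. As a reference point, here is a minimal sketch of symmetric per-channel int8 quantization with the same call shape, assuming the helper returns the int8 weights plus a float scaler such that `val ≈ new_weights * scaler` (the function name is hypothetical):

```python
import torch

def quantize_torch_int8_sketch(w, reduce_axis=(1,)):
  """Sketch: symmetric int8 quantization, one scale per slice of reduce_axis."""
  max_abs = torch.amax(w.abs(), dim=reduce_axis, keepdim=True)
  scaler = (max_abs / 127.0).clamp(min=1e-8)  # guard all-zero slices
  w_int8 = torch.round(w / scaler).clamp(-128, 127).to(torch.int8)
  return w_int8, scaler
```

With `reduce_axis=(1,)`, an `(out, in)` linear weight gets one scale per output row, so the squeezed scaler stored next to each weight above has shape `(out,)`.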
@@ -222,7 +168,9 @@ def _tensors_have_same_shape(tensors):


 # pylint: disable-next=all
-def _merge_weights(checkpoints, minimize_memory_footprint, enable_float32):
+def _merge_llama_weights(
+    checkpoints, minimize_memory_footprint, enable_float32
+):
   print("Starting to merge weights.")
   state_dict = {}
   tmp_dir: epath.Path = None
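The body of `_merge_llama_weights` lies outside this hunk. For orientation, here is a hedged sketch of the usual Megatron-style merge rule that the weight-to-parallel-layer map excerpted at the top of the file encodes; this describes the convention, not necessarily the exact code:

```python
import torch

def merge_shards_sketch(shards, kind):
  """Sketch: concatenate tensor-parallel shards; torch weights are (out, in)."""
  if kind == "ColumnParallelLinear":  # output features split across shards
    return torch.cat(shards, dim=0)
  if kind == "RowParallelLinear":  # input features split across shards
    return torch.cat(shards, dim=1)
  if kind == "ParallelEmbedding":  # embedding dim split across shards
    return torch.cat(shards, dim=1)
  return shards[0]  # replicated params are identical on every shard
```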
@@ -362,13 +310,7 @@ def _export_to_local(output_ckpt_dir: epath.Path, params, state_dict):
   checklist_file.write_text(_generate_md5_checklist(output_ckpt_dir))


-def merge_weights(
-    input_ckpt_dir: epath.Path,
-    output_ckpt_dir: epath.Path,
-    minimize_memory_footprint: bool = True,
-    enable_float32: bool = False,
-) -> None:
-  """merge weights"""
+def _get_llama_state_dict(input_ckpt_dir):
   start = time.perf_counter()
   if "gs://" in str(input_ckpt_dir):
     print(
@@ -382,35 +324,15 @@ def merge_weights(
   print(f"Loading checkpoints takes {end - start} seconds")

   start = time.perf_counter()
-  state_dict = _merge_weights(
-      checkpoints, minimize_memory_footprint, enable_float32
+  state_dict = _merge_llama_weights(
+      checkpoints, _MINIMIZE_MEMORY_FOOTPRINT.value, _ENABLE_FLOAT32.value
   )
   end = time.perf_counter()
   print(f"Merging weights takes {end - start} seconds")
-
-  if _QUANTIZE.value:
-    start = time.perf_counter()
-    state_dict = _quantize_state_dict(state_dict)
-    end = time.perf_counter()
-    print(f"Quantizing weights takes {end - start} seconds")
-
-  print(f"Writing merged weights to dir {output_ckpt_dir}")
-  start = time.perf_counter()
-  if "gs://" in str(output_ckpt_dir):
-    _export_to_gcs(output_ckpt_dir, params, state_dict)
-  else:
-    _export_to_local(output_ckpt_dir, params, state_dict)
-  end = time.perf_counter()
-  print(f"Export outputs takes {end - start} seconds")
+  return state_dict, params


-def convert_hf_gemma_weights(
-    input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path
-):
-  """Convert gemma weights from Huggingface to be compatible with JetStream
-  1. Map attention weights to new names.
-  2. Split qkv fusion.
-  """
+def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
@@ -450,24 +372,37 @@ def convert_hf_gemma_weights(

     if new_key != key:
       state_dict[new_key] = state_dict.pop(key)
-  _export_to_local(output_ckpt_dir, model_config, state_dict)
+  return state_dict, model_config


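After this refactor, both getters share one contract: load everything, return `(state_dict, params)`, and leave quantization and export to `main`. A toy stand-in that satisfies the same contract (all names and shapes here are made up):

```python
import torch

def _get_toy_state_dict(input_ckpt_dir):
  """Minimal getter with the same (state_dict, params) return shape."""
  del input_ckpt_dir  # a real getter reads checkpoints from this path
  state_dict = {"layers.0.attention.wq.weight": torch.randn(8, 16)}
  params = {"dim": 16, "n_heads": 2}  # placeholder model params
  return state_dict, params
```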
-def main(argv: Sequence[str]) -> None:
-  """convert checkpoint main function"""
-  if len(argv) > 1:
-    raise app.UsageError("Too many command-line arguments.")
-  if "gemma" in _MODEL_TYPE.value:
-    convert_hf_gemma_weights(
-        _INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value
-    )
+def main(argv) -> None:
+  """merge weights"""
+
+  if _MODEL_TYPE.value == "gemma":
+    state_dict, params = _get_gemma_state_dict(_INPUT_CHECKPOINT_DIR.value)
+    quantize_weight_map = _GEMMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME
+    weight_axis = lambda x: 0 if x == "embedder.weight" else 1
   else:
-    merge_weights(
-        _INPUT_CHECKPOINT_DIR.value,
-        _OUTPUT_CHECKPOINT_DIR.value,
-        _MINIMIZE_MEMORY_FOOTPRINT.value,
-        _ENABLE_FLOAT32.value,
+    state_dict, params = _get_llama_state_dict(_INPUT_CHECKPOINT_DIR.value)
+    quantize_weight_map = _LLAMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME
+    weight_axis = lambda x: 0 if x == "tok_embeddings.weight" else 1
+
+  if _QUANTIZE.value:
+    start = time.perf_counter()
+    state_dict = _quantize_state_dict(
+        state_dict, quantize_weight_map, weight_axis
     )
+    end = time.perf_counter()
+    print(f"Quantizing weights takes {end - start} seconds")
+
+  print(f"Writing merged weights to dir {_OUTPUT_CHECKPOINT_DIR.value}")
+  start = time.perf_counter()
+  if "gs://" in str(_OUTPUT_CHECKPOINT_DIR.value):
+    _export_to_gcs(_OUTPUT_CHECKPOINT_DIR.value, params, state_dict)
+  else:
+    _export_to_local(_OUTPUT_CHECKPOINT_DIR.value, params, state_dict)
+  end = time.perf_counter()
+  print(f"Export outputs takes {end - start} seconds")


 if __name__ == "__main__":
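A quick way to sanity-check the new path end to end is to run `_quantize_state_dict` over a toy state dict with the llama map and the axis rule from `main`. Shapes and the layer-0 key below are made up, and this assumes `quantize_torch_int8` behaves like the sketch earlier:

```python
import torch

toy = {
    "layers.0.attention.wq.weight": torch.randn(8, 16),
    "tok_embeddings.weight": torch.randn(32, 8),
}
weight_axis = lambda k: 0 if k == "tok_embeddings.weight" else 1
quantized = _quantize_state_dict(
    toy, _LLAMA_QUANTIZED_WEIGHTS_TO_SCALER_NAME, weight_axis
)
# Linear weights: per-output-row scale; embeddings: reduce over the vocab axis.
assert quantized["layers.0.attention.wq.weight"].dtype == torch.int8
assert quantized["layers.0.attention.wq.weight_scaler"].shape == (8,)
assert quantized["tok_embeddings.weight_scaler"].shape == (8,)
```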