@@ -1,6 +1,7 @@
 import copy
 import filecmp
 import json
+import os
 import pathlib
 import shutil
 import subprocess
@@ -16,6 +17,8 @@
 from safetensors import safe_open
 from safetensors.torch import save_file
 
+from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors
+
 def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
     # for now, allowlist of recipes we know how to convert and hand convert
     # them here
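(The body of this function is elided by the diff. For orientation: compressed-tensors config groups are plain dicts. Below is a hypothetical example of the kind of value this function might return for an FP8 dynamic-activation recipe; the field names follow the compressed-tensors schema as I understand it and are illustrative, not the function's verified output:)

    # illustrative only, not this function's actual return value
    {
        "input_activations": {"dynamic": True, "num_bits": 8, "strategy": "token", "symmetric": True, "type": "float"},
        "weights": {"dynamic": False, "num_bits": 8, "strategy": "channel", "symmetric": True, "type": "float"},
        "targets": ["Linear"],
    }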
@@ -55,20 +58,30 @@ def run(
     dir_validation: str = 'data/llmcompressor/fp8-opt-125m',
     skip_conversion: bool = False,
 ):
+    dir_source = dir_source.rstrip('/')
+    dir_target = dir_target.rstrip('/')
+    dir_validation = dir_validation.rstrip('/')
+
     config_name_source = f"{dir_source}/config.json"
     config_name_target = f"{dir_target}/config.json"
     config_name_validation = f"{dir_validation}/config.json"
     weights_name_source = f"{dir_source}/pytorch_model.bin"
     weights_name_target = f"{dir_target}/model.safetensors"
     weights_name_validation = f"{dir_validation}/model.safetensors"
 
+    # create the target dir if it does not yet exist
+    os.makedirs(dir_target, exist_ok=True)
+
     if not skip_conversion:
+        source_converted_filenames = set()
+
         #
         # convert config.json
         #
 
         with open(config_name_source, 'r') as f:
             config_source = json.load(f)
+        print(json.dumps(config_source, indent=2))
 
         # get torchao config format
         # example: https://www.internalfb.com/phabricator/paste/view/P1975688376
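(For illustration, the torchao-format quantization config read below has roughly this shape; the config class name here is just an example, not necessarily what this checkpoint uses:)

    # hypothetical "quantization_config" entry in config.json
    {
        "quant_method": "torchao",
        "quant_type": {
            "default": {
                "_type": "Float8DynamicActivationFloat8WeightConfig",  # example recipe
                "_data": {},  # recipe-specific fields omitted
            },
        },
    }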
@@ -78,6 +91,11 @@ def run(
         fqn_to_serialized_aobaseconfig = old_hf_quantization_config["quant_type"]
         assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"
 
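+        # ModuleFqnToConfig wraps a per-module mapping; unwrap it so the loop
+        # below iterates over {module_fqn: serialized_config_or_None} directly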
+        if fqn_to_serialized_aobaseconfig['default']['_type'] == 'ModuleFqnToConfig':
+            fqn_to_serialized_aobaseconfig = \
+                fqn_to_serialized_aobaseconfig['default']['_data']['module_fqn_to_config']
+
+
         new_hf_quantization_config = {
             "config_groups": {},
             "format": "float-quantized",
@@ -90,13 +108,14 @@ def run(
         }
 
         for fqn, serialized_aobaseconfig in fqn_to_serialized_aobaseconfig.items():
-            print(fqn, serialized_aobaseconfig)
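+            # a None entry means the module stays unquantized, which the
+            # compressed-tensors format expresses via the top-level "ignore" list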
+            if serialized_aobaseconfig is None:
+                new_hf_quantization_config['ignore'].append(fqn)
+                continue
+
             aobaseconfig = config_from_dict(serialized_aobaseconfig)
-            print(aobaseconfig)
             ct_config = ao_config_to_compressed_tensors_config(aobaseconfig)
-            print(json.dumps(ct_config, indent=2))
 
-            assert fqn == "default", "unsupported"
+            assert fqn in ("default", "_default"), "unsupported"
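+            # ("_default" is, as far as I know, ModuleFqnToConfig's catch-all
+            # key for modules without an explicit per-fqn entry)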
             new_hf_quantization_config["config_groups"]["group_0"] = ct_config
 
         # for now, modify config_source inplace
@@ -106,46 +125,58 @@ def run(
         with open(config_name_target, 'w') as f:
             json.dump(config_source, f, indent=2)
 
+        source_converted_filenames.add(config_name_source)
+
         #
         # convert the checkpoint
         #
 
         # not sure why I still need this
         torch.serialization.add_safe_globals([getattr])
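+        # (presumably because torch.load(weights_only=True) only unpickles
+        # allowlisted globals, and this checkpoint's pickle references getattr)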
 
-        old_state_dict = torch.load(weights_name_source, weights_only=True)
-        new_state_dict = {}
-
-        for k, v in old_state_dict.items():
-            print(k, v.shape, type(v))
-            if type(v) == torch.Tensor:
-
-                if "lm_head" in k:
-                    # work around issues detailed in
-                    # https://huggingface.co/docs/safetensors/torch_shared_tensors
-                    v = copy.deepcopy(v)
-
-                new_state_dict[k] = v
-            elif type(v) == Float8Tensor:
-                new_state_dict[k] = v.qdata
-                # for now, manually cast scale to bfloat16 to match current
-                # llm-compressor script
-                # TODO(future): prob needs to be user controllable
-                new_state_dict[k + '_scale'] = v.scale.bfloat16()
-            else:
-                raise AssertionError(f'unsupported type {type(v)}')
-        save_file(new_state_dict, weights_name_target)
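(The inline conversion above was presumably factored out into utils as `convert_pt_statedict_to_safetensors`. A minimal sketch of such a helper, reconstructed from the removed lines and reusing the file's existing torch/copy/Float8Tensor/save_file imports; the real utils implementation is not shown in this diff and may differ:)

    def convert_pt_statedict_to_safetensors(pt_filename, safetensors_filename):
        # hypothetical reconstruction of the helper from the removed code above
        old_state_dict = torch.load(pt_filename, weights_only=True)
        new_state_dict = {}
        for k, v in old_state_dict.items():
            if type(v) == torch.Tensor:
                if "lm_head" in k:
                    # break storage sharing, see
                    # https://huggingface.co/docs/safetensors/torch_shared_tensors
                    v = copy.deepcopy(v)
                new_state_dict[k] = v
            elif type(v) == Float8Tensor:
                # store raw FP8 data, plus a bfloat16 scale under k + '_scale'
                new_state_dict[k] = v.qdata
                new_state_dict[k + '_scale'] = v.scale.bfloat16()
            else:
                raise AssertionError(f'unsupported type {type(v)}')
        save_file(new_state_dict, safetensors_filename)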
+        is_single_chunk = os.path.isfile(f'{dir_source}/pytorch_model.bin')
+        if is_single_chunk:
+            convert_pt_statedict_to_safetensors(weights_name_source, weights_name_target)
+            source_converted_filenames.add(weights_name_source)
+        else:
+            # convert each model state_dict file
+            model_part_filenames = []
+            for file_path in pathlib.Path(dir_source).iterdir():
+                if not file_path.is_file():
+                    continue
+                if not ('pytorch_model' in str(file_path) and str(file_path).endswith('bin')):
+                    continue
+                pt_sd_filename = str(file_path)
+                # dir_source/pytorch_model-00001-of-00004.bin -> dir_target/model-00001-of-00004.safetensors
+                safetensors_sd_filename = pt_sd_filename.replace(dir_source, dir_target)
+                safetensors_sd_filename = safetensors_sd_filename.replace('pytorch_model', 'model')
+                safetensors_sd_filename = safetensors_sd_filename.replace('.bin', '.safetensors')
+                model_part_filenames.append(safetensors_sd_filename)
+                print(pt_sd_filename, safetensors_sd_filename)
+                convert_pt_statedict_to_safetensors(pt_sd_filename, safetensors_sd_filename)
+                source_converted_filenames.add(pt_sd_filename)
+
+            # convert pytorch_model.bin.index.json
+            convert_pt_multifile_index_to_safetensors(
+                f'{dir_source}/pytorch_model.bin.index.json',
+                f'{dir_target}/model.safetensors.index.json',
+                model_part_filenames,
+            )
+            source_converted_filenames.add(f'{dir_source}/pytorch_model.bin.index.json')
+
+        print(source_converted_filenames)
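(For reference, a minimal sketch of what `convert_pt_multifile_index_to_safetensors` might do, assuming the standard Hugging Face index layout of a "metadata" dict plus a "weight_map" from parameter name to shard filename; the actual utils implementation is not shown in this diff:)

    def convert_pt_multifile_index_to_safetensors(
        source_index_filename, target_index_filename, model_part_filenames,
    ):
        # sketch: rewrite shard filenames in the weight_map; a real version
        # would also add weight_map entries for the new *_scale tensors and
        # could use model_part_filenames to recompute metadata['total_size']
        with open(source_index_filename, 'r') as f:
            index = json.load(f)
        index['weight_map'] = {
            k: v.replace('pytorch_model', 'model').replace('.bin', '.safetensors')
            for k, v in index['weight_map'].items()
        }
        with open(target_index_filename, 'w') as f:
            json.dump(index, f, indent=2)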
 
         # move all the other files over
         for dir_and_file_path in pathlib.Path(dir_source).iterdir():
             if not dir_and_file_path.is_file():
                 continue
-            file_path = dir_and_file_path.parts[-1]
-            if file_path in ('config.json', 'pytorch_model.bin'):
+            if str(dir_and_file_path) in source_converted_filenames:
                 # these are converted in custom logic elsewhere in this script
                 continue
             # if we got here, we just need to copy the file over without any changes
+            file_path = dir_and_file_path.parts[-1]
             target_file_path = f"{dir_target}/{str(file_path)}"
+            print(f'copying {dir_and_file_path} to {target_file_path}')
             shutil.copyfile(dir_and_file_path, target_file_path)
 
     # validate target_dir vs validation_dir
@@ -165,9 +196,11 @@ def run(
                 # this will always fail, for now, as we are not perfectly matching
                 print(e.stderr)
 
+            # TODO(future, as needed): also validate the other files; they are
+            # unlikely to match exactly for any model with >1 chunk of state_dict
+            # files, since we are not trying to enforce that the same tensors
+            # live in the same chunks
+
         elif file_path_target == 'model.safetensors':
-            # TODO implement me
-            pass
 
             with safe_open(dir_and_file_path, framework='pt') as f_target:
                 with safe_open(dir_and_file_path_validation, framework='pt') as f_validation:
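(The diff cuts off here, so the body of these nested `with` blocks is not shown. A plausible comparison using only the public safetensors API; a sketch of what such validation might do, not the file's actual code:)

    # compare tensor names, then per-tensor metadata
    target_keys = sorted(f_target.keys())
    validation_keys = sorted(f_validation.keys())
    assert target_keys == validation_keys, 'tensor names do not match'
    for key in target_keys:
        t = f_target.get_tensor(key)
        v = f_validation.get_tensor(key)
        assert t.dtype == v.dtype, f'dtype mismatch for {key}'
        assert t.shape == v.shape, f'shape mismatch for {key}'
        # numerics are unlikely to match bit-for-bit across the two pipelines,
        # so an exact torch.equal(t, v) check may be too strict here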