import filecmp
import json
import os
import pathlib
import shutil
import subprocess

import fire
from safetensors import safe_open

import torch
from torchao.core.config import config_from_dict
from utils import (
    ao_config_to_compressed_tensors_config,
    convert_pt_multifile_index_to_safetensors,
    convert_pt_statedict_to_safetensors,
)

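# Summary (inferred from the code below): convert a torchao-serialized HF
# checkpoint (config.json + pytorch_model*.bin) into compressed-tensors format
# (config.json + model*.safetensors), then validate the result against an
# existing llm-compressor checkpoint.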
def run(
    # original torchao checkpoint
    dir_source: str = "data/torchao/fp8-opt-125m",
    # new compressed-tensors checkpoint
    dir_target: str = "data/torchao_compressed_tensors/fp8-opt-125m",
    # existing compressed-tensors checkpoint to validate against
    dir_validation: str = "data/llmcompressor/fp8-opt-125m",
    skip_conversion: bool = False,
):
    dir_source = dir_source.rstrip("/")
    dir_target = dir_target.rstrip("/")
    dir_validation = dir_validation.rstrip("/")

    config_name_source = f"{dir_source}/config.json"
    config_name_target = f"{dir_target}/config.json"
    config_name_validation = f"{dir_validation}/config.json"
    weights_name_source = f"{dir_source}/pytorch_model.bin"
    weights_name_target = f"{dir_target}/model.safetensors"
    weights_name_validation = f"{dir_validation}/model.safetensors"

    # [... unchanged lines elided in the source diff; likely output-dir setup
    # and the initialization of source_converted_filenames ...]

    # convert config.json
    #

    with open(config_name_source) as f:
        config_source = json.load(f)
    print(json.dumps(config_source, indent=2))

    # [... unchanged lines elided in the source diff, likely the start of the
    # comment below ...]
    # we need to translate it to compressed-tensors format
    # example: https://www.internalfb.com/phabricator/paste/view/P1975642629
    old_hf_quantization_config = config_source["quantization_config"]
    fqn_to_serialized_aobaseconfig = old_hf_quantization_config[
        "quant_type"
    ]
    assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"
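    # For reference, a serialized AOBaseConfig entry here looks roughly like
    # (hypothetical values, following torchao's "_type"/"_version"/"_data"
    # serialization shape):
    #   {"default": {"_type": "Float8DynamicActivationFloat8WeightConfig",
    #                "_version": 2, "_data": {...}}}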

    if (
        fqn_to_serialized_aobaseconfig["default"]["_type"]
        == "ModuleFqnToConfig"
    ):
        fqn_to_serialized_aobaseconfig = fqn_to_serialized_aobaseconfig[
            "default"
        ]["_data"]["module_fqn_to_config"]

    new_hf_quantization_config = {
        "config_groups": {},
        # [... remaining dict keys elided in the source diff ...]
        "version": "torchao_hack",
    }

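    # (the keys elided above presumably follow the compressed-tensors config
    # schema, e.g. "quant_method": "compressed-tensors" and an "ignore" list,
    # which the loop below appends to)
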
    for (
        fqn,
        serialized_aobaseconfig,
    ) in fqn_to_serialized_aobaseconfig.items():
        if serialized_aobaseconfig is None:
            new_hf_quantization_config["ignore"].append(fqn)
            continue

        aobaseconfig = config_from_dict(serialized_aobaseconfig)
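        # config_from_dict rebuilds the concrete AOBaseConfig subclass
        # instance from its serialized dict form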
        # [... unchanged lines elided in the source diff; presumably this is
        # where aobaseconfig is mapped into
        # new_hf_quantization_config["config_groups"] via
        # ao_config_to_compressed_tensors_config ...]

    config_source["quantization_config"] = new_hf_quantization_config

    # save to new location
    with open(config_name_target, "w") as f:
        json.dump(config_source, f, indent=2)

    source_converted_filenames.add(config_name_source)
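    # (files recorded in source_converted_filenames are presumably skipped by
    # the copy-remaining-files pass further down)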

    # [... unchanged lines elided in the source diff ...]

    # not sure why I still need this
    torch.serialization.add_safe_globals([getattr])

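    # (context: torch.serialization.add_safe_globals allowlists globals for
    # torch.load with weights_only=True, which otherwise refuses to unpickle
    # objects it does not recognize)
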
    is_single_chunk = os.path.isfile(f"{dir_source}/pytorch_model.bin")
    if is_single_chunk:
        convert_pt_statedict_to_safetensors(
            weights_name_source, weights_name_target
        )
        source_converted_filenames.add(weights_name_source)
    else:
        # convert each model state_dict file
        model_part_filenames = []
        for file_path in pathlib.Path(dir_source).iterdir():
            if not file_path.is_file():
                continue
            if not (
                ("pytorch_model") in str(file_path)
                and str(file_path).endswith("bin")
            ):
                continue
            pt_sd_filename = str(file_path)
            # dir_source/pytorch_model-00001-of-00004.bin -> dir_target/model-00001-of-00004.safetensors
            safetensors_sd_filename = pt_sd_filename.replace(
                dir_source, dir_target
            )
            safetensors_sd_filename = safetensors_sd_filename.replace(
                "pytorch_model", "model"
            )
            safetensors_sd_filename = safetensors_sd_filename.replace(
                ".bin", ".safetensors"
            )
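            # with the defaults above this maps, e.g.,
            # data/torchao/fp8-opt-125m/pytorch_model-00001-of-00004.bin to
            # data/torchao_compressed_tensors/fp8-opt-125m/model-00001-of-00004.safetensors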
            model_part_filenames.append(safetensors_sd_filename)
            print(pt_sd_filename, safetensors_sd_filename)
            convert_pt_statedict_to_safetensors(
                pt_sd_filename, safetensors_sd_filename
            )
            source_converted_filenames.add(pt_sd_filename)

        # convert pytorch_model.bin.index.json
        convert_pt_multifile_index_to_safetensors(
            f"{dir_source}/pytorch_model.bin.index.json",
            f"{dir_target}/model.safetensors.index.json",
            model_part_filenames,
        )
        source_converted_filenames.add(
            f"{dir_source}/pytorch_model.bin.index.json"
        )

    print(source_converted_filenames)

    # [... loop header elided in the source diff; the lines below appear to
    # run inside a loop over the remaining files in dir_source ...]
        # if we got here, we just need to copy the file over without any changes
        file_path = dir_and_file_path.parts[-1]
        target_file_path = f"{dir_target}/{str(file_path)}"
        print(f"copying {dir_and_file_path} to {target_file_path}")
        shutil.copyfile(dir_and_file_path, target_file_path)

    # validate target_dir vs validation_dir
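    # three checks follow: a textual diff for config.json, a tensor-metadata
    # comparison for model.safetensors, and a byte-level filecmp for all
    # other files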
    # [... loop header elided in the source diff; the lines below iterate over
    # the files in dir_target ...]
            continue
        file_path_target = dir_and_file_path.parts[-1]
        print("\nvalidating", file_path_target)
        dir_and_file_path_validation = (
            f"{dir_validation}/{str(file_path_target)}"
        )

        if file_path_target == "config.json":
            # for now just diff and print the output to stdout
            command = f"diff {dir_and_file_path} {dir_and_file_path_validation}"
            try:
                result = subprocess.run(
                    command,
                    capture_output=False,
                    text=True,
                    shell=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                # this will always fail, for now, as we are not perfectly matching
                print(e.stderr)

            # TODO(future, as needed): also validate the other files, they are unlikely to match
            # exactly for any model with >1 chunk of state dict files since we are not
            # trying to enforce that the same tensors live in the same chunks.

        elif file_path_target == "model.safetensors":
            with safe_open(dir_and_file_path, framework="pt") as f_target:
                with safe_open(
                    dir_and_file_path_validation, framework="pt"
                ) as f_validation:
                    k_target_seen = set()
                    for k_target in f_target.keys():
                        v_target = f_target.get_tensor(k_target)
                        v_validation = f_validation.get_tensor(k_target)
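                        # (note: get_tensor raises if k_target is absent from
                        # the validation file; k_target_seen presumably feeds
                        # a key-coverage check in the lines elided below)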

                        # ensure metadata matches
                        if v_target.shape != v_validation.shape:
                            print(
                                f"shape mismatch: {k_target=}, {v_target.shape=}, {v_validation.shape=}"
                            )

                        if v_target.dtype != v_validation.dtype:
                            print(
                                f"dtype mismatch: {k_target=}, {v_target.dtype=}, {v_validation.dtype=}"
                            )

                        # for now, no numerical checks

                    # [... unchanged lines elided in the source diff ...]

        else:
            # approx check, currently fails because modification timestamp is not the
            # same. Since we copy these files ourselves, low-pri to make this better.
            is_equal = filecmp.cmp(
                dir_and_file_path, dir_and_file_path_validation, shallow=False
            )
            print("filecmp equal", is_equal)
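            # (note: with shallow=False, filecmp.cmp compares file contents,
            # so a False here reflects differing bytes rather than timestamps)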


if __name__ == "__main__":
    fire.Fire(run)