Commit 9732514

Merge pull request #68 from vkuzo/20251003_compressed_tensors
scripts for torchao -> compressed_tensors checkpoint conversion
2 parents da4bdc1 + 11117c6 commit 9732514

4 files changed: +207 -4 lines
Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
import copy
import filecmp
import json
import pathlib
import shutil
import subprocess
from typing import Dict, Any

import fire

import torch
from torchao.core.config import AOBaseConfig, config_from_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor

from safetensors import safe_open
from safetensors.torch import save_file

def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
    # for now, allowlist the recipes we know how to convert and hand-convert
    # them here
    # for a production version, we'll need a more scalable way to do this

    assert isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig), "unsupported"
    assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"

    ct_config = {
        "format": "float-quantized",
        "input_activations": {
            "dynamic": True,
            "num_bits": 8,
            "strategy": "token",
            "symmetric": True,
            "type": "float",
        },
        "output_activations": None,
        "targets": ["Linear"],
        "weights": {
            "dynamic": False,
            "num_bits": 8,
            "observer": "minmax",
            "strategy": "channel",
            "symmetric": True,
            "type": "float",
        },
    }
    return ct_config

def run(
    # original torchao checkpoint
    dir_source: str = 'data/torchao/fp8-opt-125m',
    # new compressed-tensors checkpoint
    dir_target: str = 'data/torchao_compressed_tensors/fp8-opt-125m',
    # existing compressed-tensors checkpoint to validate against
    dir_validation: str = 'data/llmcompressor/fp8-opt-125m',
    skip_conversion: bool = False,
):
    config_name_source = f"{dir_source}/config.json"
    config_name_target = f"{dir_target}/config.json"
    config_name_validation = f"{dir_validation}/config.json"
    weights_name_source = f"{dir_source}/pytorch_model.bin"
    weights_name_target = f"{dir_target}/model.safetensors"
    weights_name_validation = f"{dir_validation}/model.safetensors"

    if not skip_conversion:
        #
        # convert config.json
        #

        with open(config_name_source, 'r') as f:
            config_source = json.load(f)

        # get torchao config format
        # example: https://www.internalfb.com/phabricator/paste/view/P1975688376
        # we need to translate it to compressed-tensors format
        # example: https://www.internalfb.com/phabricator/paste/view/P1975642629
        old_hf_quantization_config = config_source["quantization_config"]
        fqn_to_serialized_aobaseconfig = old_hf_quantization_config["quant_type"]
        assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"

        new_hf_quantization_config = {
            "config_groups": {},
            "format": "float-quantized",
            "ignore": ["lm_head"],
            "quant_method": "compressed-tensors",
            "quantization_status": "compressed",
            "sparsity_config": {},
            "transform_config": {},
            "version": "torchao_hack",
        }

        for fqn, serialized_aobaseconfig in fqn_to_serialized_aobaseconfig.items():
            print(fqn, serialized_aobaseconfig)
            aobaseconfig = config_from_dict(serialized_aobaseconfig)
            print(aobaseconfig)
            ct_config = ao_config_to_compressed_tensors_config(aobaseconfig)
            print(json.dumps(ct_config, indent=2))

            assert fqn == "default", "unsupported"
            new_hf_quantization_config["config_groups"]["group_0"] = ct_config

        # for now, modify config_source inplace
        config_source["quantization_config"] = new_hf_quantization_config

        # save to new location
        with open(config_name_target, 'w') as f:
            json.dump(config_source, f, indent=2)

        #
        # convert the checkpoint
        #

        # not sure why I still need this
        torch.serialization.add_safe_globals([getattr])

        old_state_dict = torch.load(weights_name_source, weights_only=True)
        new_state_dict = {}

        for k, v in old_state_dict.items():
            print(k, v.shape, type(v))
            if type(v) == torch.Tensor:

                if "lm_head" in k:
                    # work around issues detailed in
                    # https://huggingface.co/docs/safetensors/torch_shared_tensors
                    v = copy.deepcopy(v)

                new_state_dict[k] = v
            elif type(v) == Float8Tensor:
                new_state_dict[k] = v.qdata
                # for now, manually cast scale to bfloat16 to match current
                # llm-compressor script
                # TODO(future): prob needs to be user controllable
                new_state_dict[k + '_scale'] = v.scale.bfloat16()
            else:
                raise AssertionError(f'unsupported type {type(v)}')
        save_file(new_state_dict, weights_name_target)

        # move all the other files over
        for dir_and_file_path in pathlib.Path(dir_source).iterdir():
            if not dir_and_file_path.is_file():
                continue
            file_path = dir_and_file_path.parts[-1]
            if file_path in ('config.json', 'pytorch_model.bin'):
                # these are converted in custom logic elsewhere in this script
                continue
            # if we got here, we just need to copy the file over without any changes
            target_file_path = f"{dir_target}/{str(file_path)}"
            shutil.copyfile(dir_and_file_path, target_file_path)

    # validate target_dir vs validation_dir
    for dir_and_file_path in pathlib.Path(dir_target).iterdir():
        if not dir_and_file_path.is_file():
            continue
        file_path_target = dir_and_file_path.parts[-1]
        print("\nvalidating", file_path_target)
        dir_and_file_path_validation = f"{dir_validation}/{str(file_path_target)}"

        if file_path_target == 'config.json':
            # for now just diff and print the output to stdout
            command = f'diff {dir_and_file_path} {dir_and_file_path_validation}'
            try:
                result = subprocess.run(command, capture_output=False, text=True, shell=True, check=True)
            except subprocess.CalledProcessError as e:
                # this will always fail, for now, as we are not perfectly matching
                print(e.stderr)

        elif file_path_target == 'model.safetensors':
            # TODO implement me
            pass

            with safe_open(dir_and_file_path, framework='pt') as f_target:
                with safe_open(dir_and_file_path_validation, framework='pt') as f_validation:
                    k_target_seen = set()
                    for k_target in f_target.keys():
                        v_target = f_target.get_tensor(k_target)
                        v_validation = f_validation.get_tensor(k_target)

                        # ensure metadata matches
                        if v_target.shape != v_validation.shape:
                            print(f"shape mismatch: {k_target=}, {v_target.shape=}, {v_validation.shape=}")

                        if v_target.dtype != v_validation.dtype:
                            print(f"dtype mismatch: {k_target=}, {v_target.dtype=}, {v_validation.dtype=}")

                        # for now, no numerical checks

                        k_target_seen.add(k_target)

                    for k_validation in f_validation.keys():
                        if k_validation not in k_target_seen:
                            print(f"key {k_validation} not present in target")

        else:
            # approx check, currently fails because modification timestamp is not the
            # same. Since we copy these files ourselves, low-pri to make this better.
            is_equal = filecmp.cmp(dir_and_file_path, dir_and_file_path_validation, shallow=False)
            print('filecmp equal', is_equal)

if __name__ == '__main__':
    fire.Fire(run)
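
A minimal usage sketch for the new conversion script (the script's file path is not shown in this view, so the module name below is hypothetical). Since the script ends in fire.Fire(run), the same arguments can also be passed as command-line flags such as --dir_source and --skip_conversion:

# hypothetical module name; substitute the actual file name of the script above
from convert_torchao_to_compressed_tensors import run

# converts the torchao checkpoint into compressed-tensors layout:
# config.json gets a compressed-tensors "quantization_config", each Float8Tensor
# weight is written as its raw qdata plus a bfloat16 "<name>_scale" tensor, and
# the result is compared against an existing llm-compressor checkpoint
run(
    dir_source="data/torchao/fp8-opt-125m",
    dir_target="data/torchao_compressed_tensors/fp8-opt-125m",
    dir_validation="data/llmcompressor/fp8-opt-125m",
    skip_conversion=False,
)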

hf_torchao_vllm/inspect_llm_compressor_output.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import fire
 
 def run(
-    dir_name: str = 'data/llmcompressor/opt-125m-FP8-Dynamic',
+    dir_name: str = 'data/llmcompressor/fp8-opt-125m',
 ):
     json_config_name = f'{dir_name}/config.json'
     with open(json_config_name, 'r') as f:

hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py

Lines changed: 4 additions & 2 deletions
@@ -1,5 +1,7 @@
 # https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w8a8_fp8/llama3_example.py
 
+import torch
+
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
@@ -14,7 +16,7 @@ def run():
     MODEL_ID = "facebook/opt-125m"
 
     # Load model.
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
     # Configure the quantization algorithm and scheme.
@@ -39,7 +41,7 @@ def run():
     print("==========================================")
 
     # Save to disk in compressed-tensors format.
-    SAVE_DIR = "data/llmcompressor/" + MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-Dynamic"
+    SAVE_DIR = "data/llmcompressor/" + "fp8-" + MODEL_ID.rstrip("/").split("/")[-1]
     model.save_pretrained(SAVE_DIR)
     tokenizer.save_pretrained(SAVE_DIR)
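
For reference, a small worked example (not part of this commit) of how the new SAVE_DIR naming evaluates; the resulting directory matches the default dir_validation used by the conversion script and the default dir_name in inspect_llm_compressor_output.py:

MODEL_ID = "facebook/opt-125m"

# old naming produced "data/llmcompressor/opt-125m-FP8-Dynamic"
# new naming below evaluates to "data/llmcompressor/fp8-opt-125m"
SAVE_DIR = "data/llmcompressor/" + "fp8-" + MODEL_ID.rstrip("/").split("/")[-1]
assert SAVE_DIR == "data/llmcompressor/fp8-opt-125m"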

hf_torchao_vllm/run_quantized_model_in_vllm.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def print_vllm_torchao_quant_info(model: torch.nn.Module):
 
 def main(
     # model_name: str = "Qwen/Qwen2-7B-Instruct",
-    model_name: str = "data/fp8-opt-125m",
+    model_name: str = "data/torchao/fp8-opt-125m",
     max_tokens=64,
     tp_size: int = 1,
     compile: bool = True,
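
The default model_name now points at the torchao checkpoint directory. As a sketch (not part of this commit, and assuming vLLM can load the converted checkpoint), the compressed-tensors output of the new conversion script could be smoke-tested directly:

from vllm import LLM, SamplingParams

# point vLLM at dir_target produced by the conversion script above
llm = LLM(model="data/torchao_compressed_tensors/fp8-opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)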
