
Commit d641957

extend checkpoint conversion script for Qwen 1.5B MoE
Summary: Make everything work for a small MoE model

Test Plan: Run the resulting checkpoint in vLLM, it works and properly maps to a cutlass w8a8 fused kernel

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent cf1f6f5 commit d641957
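
The test plan loads the converted checkpoint in vLLM. A minimal sketch of that kind of smoke test, with a hypothetical checkpoint path (vLLM picks up the compressed-tensors scheme from the converted config.json):

# smoke test for a converted checkpoint; the path is hypothetical
from vllm import LLM

llm = LLM(model='data/torchao/fp8-qwen-moe-converted')
out = llm.generate(['The capital of France is'])
print(out[0].outputs[0].text)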

File tree

2 files changed: +151 -30 lines changed


hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py

Lines changed: 63 additions & 30 deletions
@@ -1,6 +1,7 @@
 import copy
 import filecmp
 import json
+import os
 import pathlib
 import shutil
 import subprocess
@@ -16,6 +17,8 @@
 from safetensors import safe_open
 from safetensors.torch import save_file

+from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors
+
 def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
     # for now, allowlist of recipes we know how to convert and hand convert
     # them here
@@ -55,20 +58,30 @@ def run(
     dir_validation: str = 'data/llmcompressor/fp8-opt-125m',
     skip_conversion: bool = False,
 ):
+    dir_source = dir_source.rstrip('/')
+    dir_target = dir_target.rstrip('/')
+    dir_validation = dir_validation.rstrip('/')
+
     config_name_source = f"{dir_source}/config.json"
     config_name_target = f"{dir_target}/config.json"
     config_name_validation = f"{dir_validation}/config.json"
     weights_name_source = f"{dir_source}/pytorch_model.bin"
     weights_name_target = f"{dir_target}/model.safetensors"
     weights_name_validation = f"{dir_validation}/model.safetensors"

+    # create new dir if not yet exists
+    os.makedirs(dir_target, exist_ok=True)
+
     if not skip_conversion:
+        source_converted_filenames = set()
+
         #
         # convert config.json
         #

         with open(config_name_source, 'r') as f:
             config_source = json.load(f)
+        print(json.dumps(config_source, indent=2))

         # get torchao config format
         # example: https://www.internalfb.com/phabricator/paste/view/P1975688376
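
A note on the new rstrip('/') normalization: the bookkeeping below compares f-string paths such as f"{dir_source}/config.json" against str() of pathlib.Path entries, and the multi-file branch derives target filenames by replacing the dir_source prefix, so a trailing slash in a user-supplied directory would otherwise produce double slashes that defeat both comparisons.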
@@ -78,6 +91,11 @@ def run(
         fqn_to_serialized_aobaseconfig = old_hf_quantization_config["quant_type"]
         assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"

+        if fqn_to_serialized_aobaseconfig['default']['_type'] == 'ModuleFqnToConfig':
+            fqn_to_serialized_aobaseconfig = \
+                fqn_to_serialized_aobaseconfig['default']['_data']['module_fqn_to_config']
+
+
         new_hf_quantization_config = {
             "config_groups": {},
             "format": "float-quantized",
@@ -90,13 +108,14 @@ def run(
         }

         for fqn, serialized_aobaseconfig in fqn_to_serialized_aobaseconfig.items():
-            print(fqn, serialized_aobaseconfig)
+            if serialized_aobaseconfig is None:
+                new_hf_quantization_config['ignore'].append(fqn)
+                continue
+
             aobaseconfig = config_from_dict(serialized_aobaseconfig)
-            print(aobaseconfig)
             ct_config = ao_config_to_compressed_tensors_config(aobaseconfig)
-            print(json.dumps(ct_config, indent=2))

-            assert fqn == "default", "unsupported"
+            assert fqn in ("default", "_default"), "unsupported"
             new_hf_quantization_config["config_groups"]["group_0"] = ct_config

         # for now, modify config_source inplace
@@ -106,46 +125,58 @@ def run(
         with open(config_name_target, 'w') as f:
             json.dump(config_source, f, indent=2)

+        source_converted_filenames.add(config_name_source)
+
         #
         # convert the checkpoint
         #

         # not sure why I still need this
         torch.serialization.add_safe_globals([getattr])

-        old_state_dict = torch.load(weights_name_source, weights_only=True)
-        new_state_dict = {}
-
-        for k, v in old_state_dict.items():
-            print(k, v.shape, type(v))
-            if type(v) == torch.Tensor:
-
-                if "lm_head" in k:
-                    # work around issues detailed in
-                    # https://huggingface.co/docs/safetensors/torch_shared_tensors
-                    v = copy.deepcopy(v)
-
-                new_state_dict[k] = v
-            elif type(v) == Float8Tensor:
-                new_state_dict[k] = v.qdata
-                # for now, manually cast scale to bfloat16 to match currnt
-                # llm-compressor script
-                # TODO(future): prob needs to be user controllable
-                new_state_dict[k + '_scale'] = v.scale.bfloat16()
-            else:
-                raise AssertionError(f'unsupported type {type(v)}')
-        save_file(new_state_dict, weights_name_target)
+        is_single_chunk = os.path.isfile(f'{dir_source}/pytorch_model.bin')
+        if is_single_chunk:
+            convert_pt_statedict_to_safetensors(weights_name_source, weights_name_target)
+            source_converted_filenames.add(weights_name_source)
+        else:
+            # convert each model state_dict file
+            model_part_filenames = []
+            for file_path in pathlib.Path(dir_source).iterdir():
+                if not file_path.is_file():
+                    continue
+                if not (('pytorch_model') in str(file_path) and str(file_path).endswith('bin')):
+                    continue
+                pt_sd_filename = str(file_path)
+                # dir_source/pytorch_model-00001-of-00004.bin -> dir_target/model-00001-of-00004.safetensors
+                safetensors_sd_filename = pt_sd_filename.replace(dir_source, dir_target)
+                safetensors_sd_filename = safetensors_sd_filename.replace('pytorch_model', 'model')
+                safetensors_sd_filename = safetensors_sd_filename.replace('.bin', '.safetensors')
+                model_part_filenames.append(safetensors_sd_filename)
+                print(pt_sd_filename, safetensors_sd_filename)
+                convert_pt_statedict_to_safetensors(pt_sd_filename, safetensors_sd_filename)
+                source_converted_filenames.add(pt_sd_filename)
+
+            # convert pytorch_model.bin.index.json
+            convert_pt_multifile_index_to_safetensors(
+                f'{dir_source}/pytorch_model.bin.index.json',
+                f'{dir_target}/model.safetensors.index.json',
+                model_part_filenames,
+            )
+            source_converted_filenames.add(f'{dir_source}/pytorch_model.bin.index.json')
+
+        print(source_converted_filenames)

     # move all the other files over
     for dir_and_file_path in pathlib.Path(dir_source).iterdir():
         if not dir_and_file_path.is_file():
             continue
-        file_path = dir_and_file_path.parts[-1]
-        if file_path in ('config.json', 'pytorch_model.bin'):
+        if str(dir_and_file_path) in source_converted_filenames:
             # these are converted in custom logic elsewhere in this script
             continue
         # if we got here, we just need to copy the file over without any changes
+        file_path = dir_and_file_path.parts[-1]
         target_file_path = f"{dir_target}/{str(file_path)}"
+        print(f'copying {dir_and_file_path} to {target_file_path}')
         shutil.copyfile(dir_and_file_path, target_file_path)

     # validate target_dir vs validation_dir
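
The multi-file branch derives each target filename purely by string replacement, which is one reason the directory arguments are normalized up front. A tiny self-contained check of the rename chain, with hypothetical paths:

# hypothetical paths exercising the rename chain used above
dir_source = 'data/torchao/fp8-qwen-moe'
dir_target = 'data/torchao/fp8-qwen-moe-converted'

p = f'{dir_source}/pytorch_model-00001-of-00002.bin'
q = (p.replace(dir_source, dir_target)
      .replace('pytorch_model', 'model')
      .replace('.bin', '.safetensors'))
assert q == f'{dir_target}/model-00001-of-00002.safetensors'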
@@ -165,9 +196,11 @@ def run(
                 # this will always fail, for now, as we are not perfectly matching
                 print(e.stderr)

+            # TODO(future, as needed): also validate the other files, they are unlikely to match
+            # exactly for any model with >1 chunk of state dict files since we are not
+            # trying to enforce that the same tensors live in the same chunks.
+
         elif file_path_target == 'model.safetensors':
-            # TODO implement me
-            pass

             with safe_open(dir_and_file_path, framework='pt') as f_target:
                 with safe_open(dir_and_file_path_validation, framework='pt') as f_validation:
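
One possible way to address the TODO above, sketched under the assumption that both checkpoints carry a model.safetensors.index.json: resolve every tensor through its index so that shard boundaries stop mattering, then compare the two dicts key by key.

import json
from safetensors import safe_open

def load_all_tensors(dir_name):
    # map each tensor to its shard via the index, then load it; which
    # shard a tensor lives in no longer affects the comparison
    with open(f'{dir_name}/model.safetensors.index.json') as f:
        weight_map = json.load(f)['weight_map']
    tensors = {}
    for key, part_name in weight_map.items():
        with safe_open(f'{dir_name}/{part_name}', framework='pt') as f:
            tensors[key] = f.get_tensor(key)
    return tensors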

hf_torchao_vllm/utils.py

Lines changed: 88 additions & 0 deletions
@@ -1,10 +1,15 @@
+import copy
 import json
 import os
+from typing import List
 import pathlib

 import safetensors
+from safetensors.torch import save_file

 import torch
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+

 torch.serialization.add_safe_globals([getattr])

@@ -57,3 +62,86 @@ def inspect_model_state_dict(dir_name, model_name, model_extension) -> None:
             continue
         print(file_path)
         _inspect_state_dict_file(file_path)
+
+def convert_pt_statedict_to_safetensors(
+    pt_statedict_filename,
+    safetensors_statedict_filename,
+) -> None:
+    old_state_dict = torch.load(pt_statedict_filename, weights_only=True)
+    new_state_dict = {}
+
+    for k, v in old_state_dict.items():
+        print(k, v.shape, type(v))
+        if type(v) == torch.Tensor:
+
+            if "lm_head" in k:
+                # work around issues detailed in
+                # https://huggingface.co/docs/safetensors/torch_shared_tensors
+                v = copy.deepcopy(v)
+
+            new_state_dict[k] = v
+        elif type(v) == Float8Tensor:
+            new_state_dict[k] = v.qdata
+            # for now, manually cast scale to bfloat16 to match current
+            # llm-compressor script
+            # TODO(future): prob needs to be user controllable
+            new_state_dict[k + '_scale'] = v.scale.bfloat16()
+        else:
+            raise AssertionError(f'unsupported type {type(v)}')
+    save_file(new_state_dict, safetensors_statedict_filename)
+
+def convert_pt_multifile_index_to_safetensors(
+    source_filename: str,
+    target_filename: str,
+    model_part_filenames: List[str],
+) -> None:
+    """
+    Source format
+
+    {
+        "metadata": {...},
+        "weight_map": {
+            "foo": "pytorch_model-00001-of-00004.bin",
+            "bar": "pytorch_model-00002-of-00004.bin",
+            ...
+        }
+    }
+
+    Target format
+
+    {
+        "metadata": {...},
+        "weight_map": {
+            # weight already in high precision
+            "foo": "pytorch_model-00001-of-00004.bin",
+            # weight originally stored as a tensor subclass, but now decomposed
+            # into qdata and scale
+            "bar": "model-00002-of-00004.safetensors",
+            "bar_scale": "model-00002-of-00004.safetensors",
+            ...
+        }
+    }
+
+    For now, metadata is not updated.
+    """
+
+    # generate the new fqn to weight location map from the new safetensors files
+    new_weight_map = {}
+    for model_part_filename in model_part_filenames:
+        # print(model_part_filename)
+
+        # get the file_name from dir_name/file_name
+        basename = os.path.basename(model_part_filename)
+        # print(basename)
+
+        with safetensors.safe_open(model_part_filename, framework='pt', device='cpu') as f:
+            for k in f.keys():
+                new_weight_map[k] = basename
+
+    # save the updated mapping
+    with open(source_filename, 'r') as f:
+        source_mapping = json.load(f)
+    source_mapping['weight_map'] = new_weight_map
+    # print(json.dumps(source_mapping, indent=2))
+    with open(target_filename, 'w') as f:
+        json.dump(source_mapping, f, indent=2)
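
Putting the two new helpers together, a minimal usage sketch with hypothetical paths, mirroring the calls the conversion script's multi-file branch makes:

from utils import (
    convert_pt_statedict_to_safetensors,
    convert_pt_multifile_index_to_safetensors,
)

# hypothetical two-shard checkpoint
model_parts = []
for i in (1, 2):
    src = f'data/source/pytorch_model-0000{i}-of-00002.bin'
    dst = f'data/target/model-0000{i}-of-00002.safetensors'
    convert_pt_statedict_to_safetensors(src, dst)
    model_parts.append(dst)

convert_pt_multifile_index_to_safetensors(
    'data/source/pytorch_model.bin.index.json',
    'data/target/model.safetensors.index.json',
    model_parts,
)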
