Commit 8cfca15

Merge pull request #69 from vkuzo/20251003_compressed_tensors_qwen
extend llmcompressor script and inspection scripts to handle Qwen-1.5…
2 parents: 9732514 + 5c7cd99

4 files changed: +105 −21 lines
hf_torchao_vllm/inspect_llm_compressor_output.py

Lines changed: 13 additions & 7 deletions

@@ -5,6 +5,8 @@
 import json
 import fire
 
+from utils import inspect_model_state_dict
+
 def run(
     dir_name: str = 'data/llmcompressor/fp8-opt-125m',
 ):
@@ -14,13 +16,17 @@ def run(
     # TODO: pretty print
     print(json.dumps(data, indent=2))
 
-    # inspect the model, saved in safetensors format
-    model_name = f'{dir_name}/model.safetensors'
-    with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
-        print(f.metadata())
-        for k in f.keys():
-            t = f.get_tensor(k)
-            print(k, t.shape, t.dtype)
+    model_name, model_extension = 'model', 'safetensors'
+    inspect_model_state_dict(dir_name, model_name, model_extension)
+
+    if False:
+        # inspect the model, saved in safetensors format
+        model_name = f'{dir_name}/model.safetensors'
+        with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
+            print(f.metadata())
+            for k in f.keys():
+                t = f.get_tensor(k)
+                print(k, t.shape, t.dtype)
 
 if __name__ == '__main__':
     fire.Fire(run)
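For reference, a minimal sketch of the equivalent direct call: fire turns `run`'s keyword argument into a `--dir_name` CLI flag, and the same inspection can be reproduced from Python via the helper added in `hf_torchao_vllm/utils.py` below (the path is the script's default):

    # Sketch: equivalent of
    #   python inspect_llm_compressor_output.py --dir_name data/llmcompressor/fp8-opt-125m
    # calling the new helper directly.
    from utils import inspect_model_state_dict

    inspect_model_state_dict('data/llmcompressor/fp8-opt-125m', 'model', 'safetensors')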

hf_torchao_vllm/inspect_torchao_output.py

Lines changed: 16 additions & 5 deletions

@@ -2,10 +2,14 @@
 # via the `torchao_hf_script.py` script
 
 import json
+import os
+import pathlib
 import torch
 import torchao # this is needed to run torch.serialization.add_safe_globals([torchao.quantization.Float8Tensor])
 import fire
 
+from utils import inspect_model_state_dict
+
 # not sure why I still need this
 torch.serialization.add_safe_globals([getattr])
 
@@ -15,14 +19,21 @@ def run(dir_name: str = 'data/torchao/fp8-opt-125m'):
     # inspect the config
     with open(json_config_name, 'r') as f:
         data = json.load(f)
-    # TODO: pretty print
     print(json.dumps(data, indent=2))
 
     # inspect the data
-    model_name = f'{dir_name}/pytorch_model.bin'
-    state_dict = torch.load(model_name, weights_only=True)
-    for k, v in state_dict.items():
-        print(k, v.shape, type(v))
+    #
+    # if there is a single chunk, the state dict is named `pytorch_model.bin`
+    #
+    # if there are multiple chunks, the state dict is spread across multiple files:
+    #
+    #   pytorch_model-00001-of-00004.bin
+    #   ...
+    #   pytorch_model-00004-of-00004.bin
+    #   pytorch_model.bin.index.json
+    #
+    model_name, model_extension = 'pytorch_model', 'bin'
+    inspect_model_state_dict(dir_name, model_name, model_extension)
 
 if __name__ == '__main__':
     fire.Fire(run)
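The comment above describes the standard Hugging Face sharded-checkpoint layout. As a side sketch, the shard assignment can also be read straight from the index file; `weight_map` is the conventional key in Hugging Face index files, and the path below is the script's default checkpoint directory:

    import json

    # Sketch: print which shard file stores each weight in a sharded checkpoint.
    with open('data/torchao/fp8-opt-125m/pytorch_model.bin.index.json') as f:
        index = json.load(f)
    for weight_name, shard_file in index['weight_map'].items():
        print(weight_name, '->', shard_file)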

hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py

Lines changed: 17 additions & 9 deletions

@@ -10,21 +10,29 @@
 
 import fire
 
-def run():
-
-    # MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-    MODEL_ID = "facebook/opt-125m"
-
+def run(model_name: str = 'facebook/opt-125m'):
     # Load model.
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+    print(model)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     # Configure the quantization algorithm and scheme.
     # In this case, we:
     # * quantize the weights to fp8 with per channel via ptq
     # * quantize the activations to fp8 with dynamic per token
     recipe = QuantizationModifier(
-        targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+        targets="Linear",
+        scheme="FP8_DYNAMIC",
+        ignore=[
+            "lm_head",
+            # for Qwen MoE, but ok to just hardcode here for now
+            # https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
+            "re:.*mlp.gate$",
+            "re:.*mlp.shared_expert_gate$",
+            # also skip attention and shared expert, to focus on MoE for now
+            "re:.*self_attn.*",
+            "re:.*shared_expert.*",
+        ],
     )
 
     # Apply quantization.
@@ -41,7 +49,7 @@ def run():
     print("==========================================")
 
     # Save to disk in compressed-tensors format.
-    SAVE_DIR = "data/llmcompressor/" + "fp8-" + MODEL_ID.rstrip("/").split("/")[-1]
+    SAVE_DIR = "data/llmcompressor/" + "fp8-" + model_name.rstrip("/").split("/")[-1]
     model.save_pretrained(SAVE_DIR)
     tokenizer.save_pretrained(SAVE_DIR)

hf_torchao_vllm/utils.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+import json
+import os
+import pathlib
+
+import safetensors
+
+import torch
+
+torch.serialization.add_safe_globals([getattr])
+
+def _inspect_state_dict_file(model_name):
+    if str(model_name).endswith('safetensors'):
+        # safetensors format
+        with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
+            print(f.metadata())
+            for k in f.keys():
+                t = f.get_tensor(k)
+                print(k, type(t), t.shape, t.dtype)
+    else:
+        # pytorch format
+        state_dict = torch.load(model_name, weights_only=True)
+        for k, v in state_dict.items():
+            print(k, type(v), v.shape, v.dtype)
+
+def inspect_model_state_dict(dir_name, model_name, model_extension) -> None:
+    """
+    Inspect the model state_dict from HuggingFace and print data to stdout.
+    For example, if model_name == `pytorch_model` and extension == `bin`,
+    1. if there is a single chunk, the state dict is named `pytorch_model.bin`
+    2. if there are multiple chunks, the state dict is spread across multiple
+       files:
+
+         pytorch_model-00001-of-00004.bin
+         ...
+         pytorch_model-00004-of-00004.bin
+         pytorch_model.bin.index.json
+    """
+    is_single_chunk = os.path.isfile(f'{dir_name}/{model_name}.{model_extension}')
+    if is_single_chunk:
+        print('single state dict file')
+        model_name = f'{dir_name}/{model_name}.{model_extension}'
+        _inspect_state_dict_file(model_name)
+    else:
+        print('multiple state dict files')
+
+        index_name = f'{dir_name}/{model_name}.{model_extension}.index.json'
+        print(index_name)
+        with open(index_name, 'r') as f:
+            data = json.load(f)
+        print(json.dumps(data, indent=2))
+
+        # iterate through each file
+        for file_path in pathlib.Path(dir_name).iterdir():
+            if not file_path.is_file():
+                continue
+            if not (model_name in str(file_path) and str(file_path).endswith(model_extension)):
+                continue
+            print(file_path)
+            _inspect_state_dict_file(file_path)
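Hypothetical usage of the new helper against a torchao-format checkpoint (the path is the default from `inspect_torchao_output.py` above); the single-file and sharded branches are chosen automatically:

    from utils import inspect_model_state_dict

    # Inspects pytorch_model.bin if present; otherwise prints
    # pytorch_model.bin.index.json and walks each pytorch_model-*.bin shard.
    inspect_model_state_dict('data/torchao/fp8-opt-125m', 'pytorch_model', 'bin')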
