Skip to content

Commit e145bda

Browse files
authored
Merge pull request #76 from vkuzo/add_ruff_linting_hf_torchao_vllm
Add ruff linting configuration to hf_torchao_vllm directory
2 parents e44cafd + 2c50e3c commit e145bda

9 files changed

+481
-243
lines changed

hf_torchao_vllm/.ruff.toml

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Ruff configuration for hf_torchao_vllm directory
2+
# See: https://docs.astral.sh/ruff/configuration/
3+
4+
# Target Python version
5+
target-version = "py312"
6+
7+
# Line length for formatting
8+
line-length = 80
9+
10+
# Enable specific rule categories
11+
lint.select = [
12+
"E", # pycodestyle errors
13+
"W", # pycodestyle warnings
14+
"F", # pyflakes
15+
"I", # isort (import sorting)
16+
"B", # flake8-bugbear (common bugs and design problems)
17+
"C4", # flake8-comprehensions (list/dict/set comprehensions)
18+
"UP", # pyupgrade (modernize Python code)
19+
"N", # pep8-naming (naming conventions)
20+
"SIM", # flake8-simplify (code simplification)
21+
]
22+
23+
# Ignore specific rules
24+
lint.ignore = [
25+
"E501", # line too long (handled by formatter)
26+
"E731", # do not assign a lambda expression, use a def
27+
"B008", # do not perform function calls in argument defaults
28+
"SIM108", # use ternary operator instead of if-else
29+
"N806", # variable name should be lowercase (for ML variables like N, K, etc.)
30+
"F841", # unused vars, TODO fix this
31+
"SIM117", # nested context, TODO fix this
32+
"SIM118", # key in dict, TODO fix this
33+
"E721", # type comparison, TODO fix this
34+
"B007", # loop control var unused, TODO fix this
35+
]
36+
37+
# Exclude files and directories
38+
exclude = [
39+
"__pycache__",
40+
"*.pyc",
41+
".git",
42+
"build",
43+
"dist",
44+
"*.egg-info",
45+
"data",
46+
"sparse_logs",
47+
]
48+
49+
[lint.per-file-ignores]
50+
# Allow unused imports in __init__.py files
51+
"__init__.py" = ["F401"]
52+
53+
[lint.isort]
54+
# Sort imports
55+
known-first-party = ["torchao", "transformers", "torch"]
56+
force-single-line = false
57+
split-on-trailing-comma = true
58+
59+
[format]
60+
# Formatting options
61+
quote-style = "double"
62+
indent-style = "space"
63+
line-ending = "auto"

hf_torchao_vllm/README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# HF -> torchao -> vLLM convenience scripts
22

3-
Example
3+
## Usage Example
44

55
```bash
66
# save a quantized model to data/nvfp4-Qwen1.5-MoE-A2.7B
@@ -9,3 +9,10 @@ python quantize_hf_model_with_torchao.py --model_name "Qwen/Qwen1.5-MoE-A2.7B" -
99
# run the model from above in vLLM
1010
python run_quantized_model_in_vllm.py --model_name "data/torchao/nvfp4-Qwen1.5-MoE-A2.7B" --compile False
1111
```
12+
13+
## Code Quality & Linting
14+
15+
```bash
16+
ruff format .
17+
ruff check . --fix
18+
```
Lines changed: 86 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,39 @@
1-
import copy
21
import filecmp
32
import json
43
import os
54
import pathlib
65
import shutil
76
import subprocess
8-
from typing import Dict, Any
97

108
import fire
11-
12-
import torch
13-
from torchao.core.config import AOBaseConfig, config_from_dict
14-
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
15-
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
16-
179
from safetensors import safe_open
18-
from safetensors.torch import save_file
1910

11+
import torch
12+
from torchao.core.config import config_from_dict
2013
from utils import (
21-
convert_pt_statedict_to_safetensors,
22-
convert_pt_multifile_index_to_safetensors,
2314
ao_config_to_compressed_tensors_config,
15+
convert_pt_multifile_index_to_safetensors,
16+
convert_pt_statedict_to_safetensors,
2417
)
2518

2619

2720
def run(
2821
# original torchao checkpoint
29-
dir_source: str = 'data/torchao/fp8-opt-125m',
22+
dir_source: str = "data/torchao/fp8-opt-125m",
3023
# new compressed-tensors checkpoint
31-
dir_target: str = 'data/torchao_compressed_tensors/fp8-opt-125m',
24+
dir_target: str = "data/torchao_compressed_tensors/fp8-opt-125m",
3225
# existing compressed-tensors checkpoint to validate against
33-
dir_validation: str = 'data/llmcompressor/fp8-opt-125m',
26+
dir_validation: str = "data/llmcompressor/fp8-opt-125m",
3427
skip_conversion: bool = False,
3528
):
36-
dir_source = dir_source.rstrip('/')
37-
dir_target = dir_target.rstrip('/')
38-
dir_validation = dir_validation.rstrip('/')
29+
dir_source = dir_source.rstrip("/")
30+
dir_target = dir_target.rstrip("/")
31+
dir_validation = dir_validation.rstrip("/")
3932

4033
config_name_source = f"{dir_source}/config.json"
4134
config_name_target = f"{dir_target}/config.json"
4235
config_name_validation = f"{dir_validation}/config.json"
43-
weights_name_source = f"{dir_source}/pytorch_model.bin"
36+
weights_name_source = f"{dir_source}/pytorch_model.bin"
4437
weights_name_target = f"{dir_target}/model.safetensors"
4538
weights_name_validation = f"{dir_validation}/model.safetensors"
4639

@@ -54,7 +47,7 @@ def run(
5447
# convert config.json
5548
#
5649

57-
with open(config_name_source, 'r') as f:
50+
with open(config_name_source) as f:
5851
config_source = json.load(f)
5952
print(json.dumps(config_source, indent=2))
6053

@@ -63,13 +56,18 @@ def run(
6356
# we need to translate it to compressed-tensors format
6457
# example: https://www.internalfb.com/phabricator/paste/view/P1975642629
6558
old_hf_quantization_config = config_source["quantization_config"]
66-
fqn_to_serialized_aobaseconfig = old_hf_quantization_config["quant_type"]
59+
fqn_to_serialized_aobaseconfig = old_hf_quantization_config[
60+
"quant_type"
61+
]
6762
assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"
6863

69-
if fqn_to_serialized_aobaseconfig['default']['_type'] == 'ModuleFqnToConfig':
70-
fqn_to_serialized_aobaseconfig = \
71-
fqn_to_serialized_aobaseconfig['default']['_data']['module_fqn_to_config']
72-
64+
if (
65+
fqn_to_serialized_aobaseconfig["default"]["_type"]
66+
== "ModuleFqnToConfig"
67+
):
68+
fqn_to_serialized_aobaseconfig = fqn_to_serialized_aobaseconfig[
69+
"default"
70+
]["_data"]["module_fqn_to_config"]
7371

7472
new_hf_quantization_config = {
7573
"config_groups": {},
@@ -82,9 +80,12 @@ def run(
8280
"version": "torchao_hack",
8381
}
8482

85-
for fqn, serialized_aobaseconfig in fqn_to_serialized_aobaseconfig.items():
83+
for (
84+
fqn,
85+
serialized_aobaseconfig,
86+
) in fqn_to_serialized_aobaseconfig.items():
8687
if serialized_aobaseconfig is None:
87-
new_hf_quantization_config['ignore'].append(fqn)
88+
new_hf_quantization_config["ignore"].append(fqn)
8889
continue
8990

9091
aobaseconfig = config_from_dict(serialized_aobaseconfig)
@@ -97,7 +98,7 @@ def run(
9798
config_source["quantization_config"] = new_hf_quantization_config
9899

99100
# save to new location
100-
with open(config_name_target, 'w') as f:
101+
with open(config_name_target, "w") as f:
101102
json.dump(config_source, f, indent=2)
102103

103104
source_converted_filenames.add(config_name_source)
@@ -109,35 +110,50 @@ def run(
109110
# not sure why I still need this
110111
torch.serialization.add_safe_globals([getattr])
111112

112-
is_single_chunk = os.path.isfile(f'{dir_source}/pytorch_model.bin')
113+
is_single_chunk = os.path.isfile(f"{dir_source}/pytorch_model.bin")
113114
if is_single_chunk:
114-
convert_pt_statedict_to_safetensors(weights_name_source, weights_name_target)
115+
convert_pt_statedict_to_safetensors(
116+
weights_name_source, weights_name_target
117+
)
115118
source_converted_filenames.add(weights_name_source)
116119
else:
117-
# convert each model state_dict file
120+
# convert each model state_dict file
118121
model_part_filenames = []
119122
for file_path in pathlib.Path(dir_source).iterdir():
120123
if not file_path.is_file():
121124
continue
122-
if not (('pytorch_model') in str(file_path) and str(file_path).endswith('bin')):
125+
if not (
126+
("pytorch_model") in str(file_path)
127+
and str(file_path).endswith("bin")
128+
):
123129
continue
124130
pt_sd_filename = str(file_path)
125131
# dir_source/pytorch_model-00001-of-00004.bin -> dir_target/model-00001-of-00004.safetensors
126-
safetensors_sd_filename = pt_sd_filename.replace(dir_source, dir_target)
127-
safetensors_sd_filename = safetensors_sd_filename.replace('pytorch_model', 'model')
128-
safetensors_sd_filename = safetensors_sd_filename.replace('.bin', '.safetensors')
132+
safetensors_sd_filename = pt_sd_filename.replace(
133+
dir_source, dir_target
134+
)
135+
safetensors_sd_filename = safetensors_sd_filename.replace(
136+
"pytorch_model", "model"
137+
)
138+
safetensors_sd_filename = safetensors_sd_filename.replace(
139+
".bin", ".safetensors"
140+
)
129141
model_part_filenames.append(safetensors_sd_filename)
130142
print(pt_sd_filename, safetensors_sd_filename)
131-
convert_pt_statedict_to_safetensors(pt_sd_filename, safetensors_sd_filename)
143+
convert_pt_statedict_to_safetensors(
144+
pt_sd_filename, safetensors_sd_filename
145+
)
132146
source_converted_filenames.add(pt_sd_filename)
133147

134148
# convert pytorch_model.bin.index.json
135149
convert_pt_multifile_index_to_safetensors(
136-
f'{dir_source}/pytorch_model.bin.index.json',
137-
f'{dir_target}/model.safetensors.index.json',
150+
f"{dir_source}/pytorch_model.bin.index.json",
151+
f"{dir_target}/model.safetensors.index.json",
138152
model_part_filenames,
139153
)
140-
source_converted_filenames.add(f'{dir_source}/pytorch_model.bin.index.json')
154+
source_converted_filenames.add(
155+
f"{dir_source}/pytorch_model.bin.index.json"
156+
)
141157

142158
print(source_converted_filenames)
143159

@@ -151,7 +167,7 @@ def run(
151167
# if we got here, we just need to copy the file over without any changes
152168
file_path = dir_and_file_path.parts[-1]
153169
target_file_path = f"{dir_target}/{str(file_path)}"
154-
print(f'copying {dir_and_file_path} to {target_file_path}')
170+
print(f"copying {dir_and_file_path} to {target_file_path}")
155171
shutil.copyfile(dir_and_file_path, target_file_path)
156172

157173
# validate target_dir vs validation_dir
@@ -160,36 +176,49 @@ def run(
160176
continue
161177
file_path_target = dir_and_file_path.parts[-1]
162178
print("\nvalidating", file_path_target)
163-
dir_and_file_path_validation = f"{dir_validation}/{str(file_path_target)}"
179+
dir_and_file_path_validation = (
180+
f"{dir_validation}/{str(file_path_target)}"
181+
)
164182

165-
if file_path_target == 'config.json':
183+
if file_path_target == "config.json":
166184
# for now just diff and print the output to stdout
167-
command = f'diff {dir_and_file_path} {dir_and_file_path_validation}'
185+
command = f"diff {dir_and_file_path} {dir_and_file_path_validation}"
168186
try:
169-
result = subprocess.run(command, capture_output=False, text=True, shell=True, check=True)
187+
result = subprocess.run(
188+
command,
189+
capture_output=False,
190+
text=True,
191+
shell=True,
192+
check=True,
193+
)
170194
except subprocess.CalledProcessError as e:
171195
# this will always fail, for now, as we are not perfectly matching
172-
print(e.stderr)
196+
print(e.stderr)
173197

174198
# TODO(future, as needed): also validate the other files, they are unlikely to match
175199
# exactly for any model with >1 chunk of state dict files since we are not
176200
# trying to enforce that the same tensors live in the same chunks.
177201

178-
elif file_path_target == 'model.safetensors':
179-
180-
with safe_open(dir_and_file_path, framework='pt') as f_target:
181-
with safe_open(dir_and_file_path_validation, framework='pt') as f_validation:
202+
elif file_path_target == "model.safetensors":
203+
with safe_open(dir_and_file_path, framework="pt") as f_target:
204+
with safe_open(
205+
dir_and_file_path_validation, framework="pt"
206+
) as f_validation:
182207
k_target_seen = set()
183208
for k_target in f_target.keys():
184209
v_target = f_target.get_tensor(k_target)
185210
v_validation = f_validation.get_tensor(k_target)
186211

187212
# ensure metadata matches
188213
if v_target.shape != v_validation.shape:
189-
print(f"shape mismatch: {k_target=}, {v_target.shape=}, {v_validation.shape=}")
214+
print(
215+
f"shape mismatch: {k_target=}, {v_target.shape=}, {v_validation.shape=}"
216+
)
190217

191-
if v_target.dtype != v_validation.dtype:
192-
print(f"dtype mismatch: {k_target=}, {v_target.dtype=}, {v_validation.dtype=}")
218+
if v_target.dtype != v_validation.dtype:
219+
print(
220+
f"dtype mismatch: {k_target=}, {v_target.dtype=}, {v_validation.dtype=}"
221+
)
193222

194223
# for now, no numerical checks
195224

@@ -202,8 +231,11 @@ def run(
202231
else:
203232
# approx check, currently fails because modification timestamp is not the
204233
# same. Since we copy these files ourselves, low-pri to make this better.
205-
is_equal = filecmp.cmp(dir_and_file_path, dir_and_file_path_validation, shallow=False)
206-
print('filecmp equal', is_equal)
234+
is_equal = filecmp.cmp(
235+
dir_and_file_path, dir_and_file_path_validation, shallow=False
236+
)
237+
print("filecmp equal", is_equal)
238+
207239

208-
if __name__ == '__main__':
240+
if __name__ == "__main__":
209241
fire.Fire(run)
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
# inspects the output of a model created with llm-compressor
22
# via the `run_llm_compressor.py` script
33

4-
import safetensors
54
import json
5+
66
import fire
77

88
from utils import inspect_model_state_dict
99

10+
1011
def run(
11-
dir_name: str = 'data/llmcompressor/fp8-opt-125m',
12+
dir_name: str = "data/llmcompressor/fp8-opt-125m",
1213
):
13-
json_config_name = f'{dir_name}/config.json'
14-
with open(json_config_name, 'r') as f:
14+
json_config_name = f"{dir_name}/config.json"
15+
with open(json_config_name) as f:
1516
data = json.load(f)
1617
# TODO: pretty print
1718
print(json.dumps(data, indent=2))
1819

19-
model_name, model_extension = 'model', 'safetensors'
20+
model_name, model_extension = "model", "safetensors"
2021
inspect_model_state_dict(dir_name, model_name, model_extension)
2122

22-
if __name__ == '__main__':
23+
24+
if __name__ == "__main__":
2325
fire.Fire(run)

0 commit comments

Comments
 (0)