
Commit b130677

Merge pull request #74 from vkuzo/20251007_torchao_nvfp4_to_llmcompressor
convert dense nvfp4 checkpoint to compressed-tensors, and run in vllm
2 parents 53777ac + d06c1a9 commit b130677

5 files changed: +125 -32 lines changed

hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py

Lines changed: 6 additions & 31 deletions
@@ -17,37 +17,12 @@
 from safetensors import safe_open
 from safetensors.torch import save_file
 
-from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors
-
-def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
-    # for now, allowlist of recipes we know how to convert and hand convert
-    # them here
-    # for a production version, we'll need a more scalable way to do this
-
-    assert isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig), "unsupported"
-    assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"
-
-    ct_config = {
-        "format": "float-quantized",
-        "input_activations": {
-            "dynamic": True,
-            "num_bits": 8,
-            "strategy": "token",
-            "symmetric": True,
-            "type": "float",
-        },
-        "output_activations": None,
-        "targets": ["Linear"],
-        "weights": {
-            "dynamic": False,
-            "num_bits": 8,
-            "observer": "minmax",
-            "strategy": "channel",
-            "symmetric": True,
-            "type": "float",
-        },
-    }
-    return ct_config
+from utils import (
+    convert_pt_statedict_to_safetensors,
+    convert_pt_multifile_index_to_safetensors,
+    ao_config_to_compressed_tensors_config,
+)
+
 
 def run(
     # original torchao checkpoint

hf_torchao_vllm/inspect_torchao_output.py

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,13 @@
 
 from utils import inspect_model_state_dict
 
+# ensure NVFP4Tensor can be loaded
+import torchao.prototype.mx_formats.inference_workflow
+
+# TODO: ensure the line below happens in torchao
+import torchao
+torch.serialization.add_safe_globals([torchao.prototype.mx_formats.nvfp4_tensor.QuantizeTensorToNVFP4Kwargs])
+
 # not sure why I still need this
 torch.serialization.add_safe_globals([getattr])
 
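For context on the two add_safe_globals calls above: torch.load with weights_only=True only unpickles classes that are on an explicit allowlist, so the torchao classes used by the quantized checkpoint must be registered before it can be deserialized. A minimal sketch of the pattern, assuming a placeholder checkpoint path:

    import torch
    # ensure NVFP4Tensor can be loaded (same import as in the diff above)
    import torchao.prototype.mx_formats.inference_workflow
    from torchao.prototype.mx_formats.nvfp4_tensor import QuantizeTensorToNVFP4Kwargs

    # allowlist the kwargs class so weights_only unpickling accepts it
    torch.serialization.add_safe_globals([QuantizeTensorToNVFP4Kwargs])

    # "model.pt" is a placeholder path, not taken from this repo
    state_dict = torch.load("model.pt", weights_only=True)
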
hf_torchao_vllm/quantize_hf_model_with_torchao.py

Lines changed: 9 additions & 0 deletions
@@ -147,6 +147,15 @@ def get_quantization_config(args):
     single_config = NVFP4InferenceConfig(
         mm_config=NVFP4MMConfig.WEIGHT_ONLY,
         use_triton_kernel=False,
+        #
+        # weight_only and use_dynamic_per_tensor_scale=True works here
+        # but garbage output in vLLM, probably because we currently don't have a way
+        # in torchao to enforce the scales for attention and ffn weights that
+        # are going to be fused for inference to be the same
+        # TODO: file a torchao issue about this, and fix in torchao
+        #
+        # dynamic and use_dynamic_per_tensor_scale=False not supported in torch._scaled_mm
+        #
         use_dynamic_per_tensor_scale=False,
     )
     if args.experts_only_qwen_1_5_moe_a_2_7b:

hf_torchao_vllm/run_quantized_model_in_vllm.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ def print_vllm_torchao_quant_info(model: torch.nn.Module):
     for name, mod in model.named_modules():
         if "Linear" not in str(type(mod)):
             continue
+        if not hasattr(mod, "weight"):
+            continue
         mod_and_weight_type = type(mod), type(mod.weight)
         if mod_and_weight_type in seen_types:
             continue

hf_torchao_vllm/utils.py

Lines changed: 101 additions & 1 deletion
@@ -1,13 +1,21 @@
 import copy
 import json
 import os
-from typing import List
+from typing import List, Dict, Any
 import pathlib
 
 import safetensors
 from safetensors.torch import save_file
 
 import torch
+import torchao
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+
+from torchao.core.config import AOBaseConfig, config_from_dict
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+from torchao.prototype.mx_formats.inference_workflow import NVFP4InferenceConfig
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+from torchao.prototype.mx_formats.utils import from_blocked
 from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 
 
@@ -80,12 +88,42 @@ def convert_pt_statedict_to_safetensors(
             v = copy.deepcopy(v)
 
             new_state_dict[k] = v
+
         elif type(v) == Float8Tensor:
             new_state_dict[k] = v.qdata
             # for now, manually cast scale to bfloat16 to match current
             # llm-compressor script
             # TODO(future): prob needs to be user controllable
             new_state_dict[k + '_scale'] = v.scale.bfloat16()
+
+        elif type(v) == NVFP4Tensor:
+            # example checkpoint format: https://www.internalfb.com/phabricator/paste/view/P1981272933
+
+            # torchao does not support nvfp4 activation calibration yet,
+            # set activation global scale to 1.0
+            new_state_dict[k.replace('weight', 'input_global_scale')] = torch.tensor([1.0])
+            # torchao does not support fusion-aware nvfp4 weight global scale yet,
+            # set weight scale to 1.0
+            new_state_dict[k.replace('weight', 'weight_global_scale')] = torch.tensor([1.0])
+            new_state_dict[k + '_packed'] = v.qdata.view(torch.uint8)
+            # compressed-tensors stores the nvfp4 scale in row-major format,
+            # convert from swizzled to row-major
+            swizzled_scale = v._scale_e4m3
+            original_rows = v.qdata.shape[0]
+            # multiply by 2 to undo the packing, then divide by nvfp4 block size of 16
+            original_cols = v.qdata.shape[1] * 2 // 16
+            # TODO(future) also do the padding calculation here and remove the
+            # assertions
+            assert original_rows % 128 == 0, "unsupported"
+            assert original_cols % 4 == 0, "unsupported"
+            # import pdb; pdb.set_trace()
+            row_major_scale = from_blocked(
+                swizzled_scale,
+                original_rows,
+                original_cols,
+            )
+            new_state_dict[k + '_scale'] = row_major_scale
+
         else:
             raise AssertionError(f'unsupported type {type(v)}')
     save_file(new_state_dict, safetensors_statedict_filename)
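
To make the scale-shape arithmetic in the NVFP4Tensor branch concrete, a small worked example (the 4096x4096 weight shape is illustrative, not taken from the diff): qdata packs two fp4 values per uint8, so the weight is stored as 4096x2048 bytes, and with the nvfp4 block size of 16 there is one e4m3 scale per 16 unpacked elements in each row:

    # illustrative shapes only; "v" plays the role of the NVFP4Tensor in the diff
    rows, unpacked_cols = 4096, 4096
    packed_cols = unpacked_cols // 2      # qdata stores two fp4 values per uint8 -> 2048
    scale_cols = packed_cols * 2 // 16    # one e4m3 scale per block of 16 values -> 256
    # these are the same conditions the assertions above check before converting
    assert rows % 128 == 0 and scale_cols % 4 == 0
    # from_blocked(v._scale_e4m3, rows, scale_cols) then returns the scale
    # in row-major layout, shaped (rows, scale_cols)
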
@@ -145,3 +183,65 @@ def convert_pt_multifile_index_to_safetensors(
     # print(json.dumps(source_mapping, indent=2))
     with open(target_filename, 'w') as f:
         json.dump(source_mapping, f, indent=2)
+
+
+def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
+    # for now, allowlist of recipes we know how to convert and hand convert
+    # them here
+    # for a production version, we'll need a more scalable way to do this
+
+    if isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig):
+        assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"
+
+        ct_config = {
+            "format": "float-quantized",
+            "input_activations": {
+                "dynamic": True,
+                "num_bits": 8,
+                "strategy": "token",
+                "symmetric": True,
+                "type": "float",
+            },
+            "output_activations": None,
+            "targets": ["Linear"],
+            "weights": {
+                "dynamic": False,
+                "num_bits": 8,
+                "observer": "minmax",
+                "strategy": "channel",
+                "symmetric": True,
+                "type": "float",
+            },
+        }
+
+    elif isinstance(aobaseconfig, NVFP4InferenceConfig):
+
+        ct_config = {
+            "format": "nvfp4-pack-quantized",
+            "input_activations": {
+                "dynamic": "local",
+                "group_size": 16,
+                "num_bits": 4,
+                "observer": "minmax",
+                "observer_kwargs": {},
+                "strategy": "tensor_group",
+                "symmetric": True,
+                "type": "float",
+            },
+            "output_activations": None,
+            "targets": ["Linear"],
+            "weights": {
+                "dynamic": False,
+                "group_size": 16,
+                "num_bits": 4,
+                "observer": "minmax",
+                "observer_kwargs": {},
+                "strategy": "tensor_group",
+                "symmetric": True,
+                "type": "float"
+            },
+        }
+
+    else:
+        raise AssertionError(f"unsupported type {type(aobaseconfig)}")
+    return ct_config
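
A minimal usage sketch for ao_config_to_compressed_tensors_config; the surrounding keys ("quant_method", "config_groups", "group_0") follow the usual compressed-tensors config.json layout and are assumptions here, not taken from this diff:

    import json
    from torchao.prototype.mx_formats.inference_workflow import NVFP4InferenceConfig
    from utils import ao_config_to_compressed_tensors_config

    # default-constructed config; the conversion script reconstructs the real one
    # from the torchao checkpoint's config.json
    ct_scheme = ao_config_to_compressed_tensors_config(NVFP4InferenceConfig())
    quantization_config = {
        "quant_method": "compressed-tensors",
        "format": ct_scheme["format"],
        "config_groups": {"group_0": ct_scheme},
    }
    print(json.dumps(quantization_config, indent=2))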
