|
1 | 1 | import copy |
2 | 2 | import json |
3 | 3 | import os |
4 | | -from typing import List |
| 4 | +from typing import List, Dict, Any |
5 | 5 | import pathlib |
6 | 6 |
|
7 | 7 | import safetensors |
8 | 8 | from safetensors.torch import save_file |
9 | 9 |
|
10 | 10 | import torch |
| 11 | +import torchao |
| 12 | +from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor |
| 13 | + |
| 14 | +from torchao.core.config import AOBaseConfig, config_from_dict |
| 15 | +from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow |
| 16 | +from torchao.prototype.mx_formats.inference_workflow import NVFP4InferenceConfig |
| 17 | +from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor |
| 18 | +from torchao.prototype.mx_formats.utils import from_blocked |
11 | 19 | from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor |
12 | 20 |
|
13 | 21 |
|
@@ -80,12 +88,42 @@ def convert_pt_statedict_to_safetensors( |
80 | 88 | v = copy.deepcopy(v) |
81 | 89 |
|
82 | 90 | new_state_dict[k] = v |
| 91 | + |
83 | 92 | elif type(v) == Float8Tensor: |
84 | 93 | new_state_dict[k] = v.qdata |
85 | 94 | # for now, manually cast scale to bfloat16 to match current |
86 | 95 | # llm-compressor script |
87 | 96 | # TODO(future): prob needs to be user controllable |
88 | 97 | new_state_dict[k + '_scale'] = v.scale.bfloat16() |
| 98 | + |
| 99 | + elif type(v) == NVFP4Tensor: |
| 100 | + # example checkpoint format: https://www.internalfb.com/phabricator/paste/view/P1981272933 |
| 101 | + |
| 102 | + # torchao does not support nvfp4 activation calibration yet, |
| 103 | + # set activation global scale to 1.0 |
| 104 | + new_state_dict[k.replace('weight', 'input_global_scale')] = torch.tensor([1.0]) |
| 105 | + # torchao does not support fusion-aware nvfp4 weight global scale yet, |
| 106 | + # set weight scale to 1.0 |
| 107 | + new_state_dict[k.replace('weight', 'weight_global_scale')] = torch.tensor([1.0]) |
| 108 | + new_state_dict[k + '_packed'] = v.qdata.view(torch.uint8) |
| 109 | + # compressed-tensors stores the nvfp4 scale in row-major format, |
| 110 | + # convert from swizzled to row-major |
| 111 | + swizzled_scale = v._scale_e4m3 |
| 112 | + original_rows = v.qdata.shape[0] |
| 113 | + # multiply by 2 to undo the packing, then divide by nvfp4 block size of 16 |
| 114 | + original_cols = v.qdata.shape[1] * 2 // 16 |
| 115 | + # TODO(future) also do the padding calculation here and remove the |
| 116 | + # assertions |
| 117 | + assert original_rows % 128 == 0, "unsupported" |
| 118 | + assert original_cols % 4 == 0, "unsupported" |
| 119 | + # import pdb; pdb.set_trace() |
| 120 | + row_major_scale = from_blocked( |
| 121 | + swizzled_scale, |
| 122 | + original_rows, |
| 123 | + original_cols, |
| 124 | + ) |
| 125 | + new_state_dict[k + '_scale'] = row_major_scale |
| 126 | + |
89 | 127 | else: |
90 | 128 | raise AssertionError(f'unsupported type {type(v)}') |
91 | 129 | save_file(new_state_dict, safetensors_statedict_filename) |
@@ -145,3 +183,65 @@ def convert_pt_multifile_index_to_safetensors( |
145 | 183 | # print(json.dumps(source_mapping, indent=2)) |
146 | 184 | with open(target_filename, 'w') as f: |
147 | 185 | json.dump(source_mapping, f, indent=2) |
| 186 | + |
| 187 | + |
def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
    """Translate a torchao quantization config into the equivalent
    compressed-tensors config-group dict.

    For now this is an allowlist of recipes we know how to convert, hand
    translated inline.  A production version will need a more scalable
    mapping mechanism.

    Args:
        aobaseconfig: the torchao config to translate.  Currently supported:
            * ``Float8DynamicActivationFloat8WeightConfig`` with per-row
              granularity for both activations and weights
            * ``NVFP4InferenceConfig``

    Returns:
        A dict in the compressed-tensors quantization config format.

    Raises:
        AssertionError: if the config type, or its settings, are outside the
            allowlist above.
    """
    if isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig):
        # Only per-row scaling (token-strategy activations, channel-strategy
        # weights) is mapped today.  Raise explicitly instead of using a bare
        # `assert` so the check is not stripped under `python -O`; keep
        # AssertionError so existing callers' handling is unchanged.
        if aobaseconfig.granularity != [PerRow(), PerRow()]:
            raise AssertionError(
                f"unsupported granularity {aobaseconfig.granularity}"
            )

        ct_config = {
            "format": "float-quantized",
            "input_activations": {
                "dynamic": True,
                "num_bits": 8,
                "strategy": "token",
                "symmetric": True,
                "type": "float",
            },
            "output_activations": None,
            "targets": ["Linear"],
            "weights": {
                "dynamic": False,
                "num_bits": 8,
                "observer": "minmax",
                "strategy": "channel",
                "symmetric": True,
                "type": "float",
            },
        }

    elif isinstance(aobaseconfig, NVFP4InferenceConfig):
        # NVFP4: 4-bit float, block size 16, with "local" dynamic activation
        # quantization (per-group scales computed at runtime, global scale
        # calibrated offline).
        ct_config = {
            "format": "nvfp4-pack-quantized",
            "input_activations": {
                "dynamic": "local",
                "group_size": 16,
                "num_bits": 4,
                "observer": "minmax",
                "observer_kwargs": {},
                "strategy": "tensor_group",
                "symmetric": True,
                "type": "float",
            },
            "output_activations": None,
            "targets": ["Linear"],
            "weights": {
                "dynamic": False,
                "group_size": 16,
                "num_bits": 4,
                "observer": "minmax",
                "observer_kwargs": {},
                "strategy": "tensor_group",
                "symmetric": True,
                "type": "float",
            },
        }

    else:
        raise AssertionError(f"unsupported type {type(aobaseconfig)}")
    return ct_config
0 commit comments