
Commit cc4fc0c

add mxfp8 qat (#2299)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1317187 commit cc4fc0c

File tree

10 files changed, +1023 -1 lines changed

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=import-error
"""QAT (Quantization Aware Tuning)."""
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantized Linear."""


import torch
import torch.nn as nn
import torch.nn.functional as F

from .tensor_quantizer import TensorQuantizer


class QuantLinear(nn.Module):
    """Quantized version of nn.Linear."""

    def forward(self, input: torch.Tensor):
        """Add weight/input/output quantization around the original forward method."""
        qw = self.weight_quantizer(self.weight)
        qi = self.input_quantizer(input)
        out = F.linear(qi, qw, self.bias)
        out = self.output_quantizer(out)
        return out

    def _setup(self, quant_cfg):
        """Init quantizers."""
        self.weight_quantizer = TensorQuantizer(
            data_type=quant_cfg.data_type,
            block_size=quant_cfg.group_size,
            bits=quant_cfg.bits,
            sym=quant_cfg.sym,
            if_quant=True,
            learn_exponent=False,
        )
        self.input_quantizer = TensorQuantizer(
            data_type=quant_cfg.act_data_type,
            block_size=quant_cfg.act_group_size,
            bits=quant_cfg.act_bits,
            sym=quant_cfg.act_sym,
            if_quant=True,
            learn_exponent=False,
        )
        self.output_quantizer = TensorQuantizer(
            data_type=quant_cfg.act_data_type,
            block_size=quant_cfg.act_group_size,
            bits=quant_cfg.act_bits,
            sym=quant_cfg.act_sym,
            if_quant=False,
        )
        # Currently we don't quantize the output.
        self.output_quantizer.disable()

        # TODO: remove
        self.original_weight_dtype = None if self.weight is None else self.weight.dtype

    def extra_repr(self) -> str:
        """Generate extra_repr, making sure important keys exist in self.__dict__."""
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"

    def __repr__(self):
        """Override __repr__ to make the output more concise and meaningful."""
        return (
            f"QuantLinear(\n"
            f"  in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}\n"
            f"  (input_quantizer): {self.input_quantizer}\n"
            f"  (output_quantizer): {self.output_quantizer}\n"
            f"  (weight_quantizer): {self.weight_quantizer}\n"
            f")"
        )
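
Worth noting: QuantLinear does not re-implement the linear layer. convert() in the utils file below rebinds an existing nn.Linear instance to this class, so the original weight and bias tensors stay in place and only fake quantization is inserted around F.linear, which is what lets gradients keep flowing to the full-precision weights during QAT. TensorQuantizer itself is not part of this diff; the sketch below reproduces the same pattern with a hypothetical straight-through-estimator quantizer standing in for it (int8-style absmax scaling, purely illustrative).

import torch
import torch.nn as nn
import torch.nn.functional as F


class _FakeQuantSTE(torch.autograd.Function):
    """Round in the forward pass, pass gradients through unchanged (STE)."""

    @staticmethod
    def forward(ctx, x, scale):
        return torch.clamp(torch.round(x / scale), -128, 127) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class ToyQuantLinear(nn.Linear):
    """Illustrative stand-in for QuantLinear: fake-quantize weight and input."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        w_scale = self.weight.detach().abs().amax() / 127.0 + 1e-12
        x_scale = input.detach().abs().amax() / 127.0 + 1e-12
        qw = _FakeQuantSTE.apply(self.weight, w_scale)
        qx = _FakeQuantSTE.apply(input, x_scale)
        return F.linear(qx, qw, self.bias)


layer = ToyQuantLinear(16, 8)
layer(torch.randn(4, 16)).sum().backward()
print(layer.weight.grad.shape)  # gradients reach the fp32 weights via the STE
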
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for quantization."""

import types
from typing import Any

import torch
import torch.nn as nn

from .quant_linear import QuantLinear
from .tensor_quantizer import TensorQuantizer


def convert(module: nn.Module, quant_cfg=None, quant_module=None):
    """Convert the module to a quantized one with the given quant config."""

    # update class
    original_cls = type(module)
    module.__class__ = quant_module
    module.forward = types.MethodType(quant_module.forward, module)

    # setup quantizers
    module._setup(quant_cfg)

    return module


def replace_with_quant_linear(model, quant_cfg=None):
    """Recursively replace modules with their quantized counterparts."""

    # TODO: support more modules, like kv.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            if "lm_head" in name:
                continue
            # Replace on the parent (model), not on the child
            quantized = convert(child, quant_cfg, QuantLinear)
            setattr(model, name, quantized)

        # now recurse into whichever module is now at `model.name`
        replace_with_quant_linear(getattr(model, name), quant_cfg=quant_cfg)

    return model


def get_quant_config_with_scheme(scheme: str):
    """Get the quantization config for a preset scheme name."""

    try:
        # use scheme definitions from AutoRound since we utilize its quantization functions now
        from auto_round.schemes import preset_name_to_scheme

        quant_cfg = preset_name_to_scheme(scheme)
        return quant_cfg
    except ImportError:
        return None


def convert_model_with_mapping(model, mapping=None):
    """Process the mapping into quant configs and apply them to the model."""
    # key is a torch module; TODO: support more key formats, like layer name.
    for key in mapping:
        # TODO: support more torch modules
        if key == nn.Linear:
            quant_cfg = get_quant_config_with_scheme(mapping[key])
            if quant_cfg is None:
                continue
            replace_with_quant_linear(model, quant_cfg)

    replaced_modules = sum(isinstance(m, TensorQuantizer) for _, m in model.named_modules())
    print(f"Inserted {replaced_modules} quantizers")


def get_quant_config(scheme: str) -> dict[str, Any]:
    """Generate the quantization config for a torch model.

    Args:
        scheme: The quantization scheme name.

    Returns:
        Dictionary containing the quantization configuration
    """

    # TODO: support more quant configs
    try:
        from auto_round.export.export_to_llmcompressor.config import initialize_quantization

        quantization_config = initialize_quantization(scheme=scheme)
        quantization_config = quantization_config.to_dict()
        quantization_config["provider"] = "auto-round"
        quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] = True
        quantization_config["config_groups"]["group_0"]["input_activations"]["is_mx"] = True

    except ImportError:
        quantization_config = None

    return quantization_config


def get_quantization_format(module) -> str | None:
    """Get the quantization format string.

    The format is found by iterating through the module and its children;
    the first non-None quantization string is returned.
    """

    def _get_quantization_from_layer(layer):
        weight_quantizer = getattr(layer, "weight_quantizer", None)
        input_quantizer = getattr(layer, "input_quantizer", None)

        if weight_quantizer is None or weight_quantizer._disabled:
            return None

        # TODO: support more quant formats
        if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8":
            return "MXFP8"

        # Raise an error for unsupported num_bits
        raise NotImplementedError(f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}")

    quantization = _get_quantization_from_layer(module)
    if quantization is not None:
        return quantization

    for _, layer in module.named_children():
        format = get_quantization_format(layer)
        if format is not None:
            return format

    return None


def is_quantlinear(module: nn.Module) -> bool:
    """Return whether the module is a quantized linear layer."""
    return "QuantLinear" in type(module).__name__
