Skip to content

Commit e36230e

Browse files
authored
Update MXQuant doc (#2309)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 24871ad commit e36230e

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

docs/source/3x/PT_MXQuant.md

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,32 +70,59 @@ Neural Compressor seamlessly applies the MX data type to post-training quantizat
7070
</left>
7171
</a>
7272

73-
The memory and computational limits of LLMs are more severe than other general neural networks, so our exploration focuses on LLMs first. The following table shows the basic MX quantization recipes in Neural Compressor and enumerates distinctions among various data types. The MX data type replaces general float scale with powers of two to be more hardware-friendly. It adapts a granularity falling between per-channel and per-tensor to balance accuracy and memory consumption.
73+
The memory and computational limits of LLMs are more severe than other general neural networks, so our exploration focuses on LLMs first. The following table shows the basic MX quantization recipes in Neural Compressor and enumerates distinctions among various data types. The MX data type replaces general float scale with powers of two to be more hardware-friendly.
7474

7575
| | MX Format | INT8 | FP8 |
7676
|------------|--------------|------------|------------|
7777
| Scale | $2^{exp}$ | $\frac{MAX}{amax}$ | $\frac{MAX}{amax}$ |
7878
| Zero point | 0 (None) | $2^{bits - 1}$ or $-min * scale$ | 0 (None) |
7979
| Granularity | per-block (default blocksize is 32) | per-channel or per-tensor | per-channel or per-tensor |
8080

81-
The exponent (exp) is equal to torch.floor(torch.log2(amax)), MAX is the representation range of the data type, amax is the max absolute value of per-block tensor, and rmin is the minimum value of the per-block tensor.
81+
The exponent (exp) is equal to clamp(floor(log2(amax)) - maxExp, -127, 127), MAX is the representation range of the data type, amax is the max absolute value of per-block tensor, and rmin is the minimum value of the per-block tensor.
8282

8383

8484
## Get Started with Microscaling Quantization API
8585

86-
To get a model quantized with Microscaling Data Types, users can use the Microscaling Quantization API as follows.
86+
To get a model quantized with Microscaling Data Types, users can use the AutoRound Quantization API as follows.
8787

8888
```python
89-
from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert
90-
91-
quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
92-
user_model = prepare(model=user_model, quant_config=quant_config)
93-
user_model = convert(model=user_model)
89+
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert
90+
from transformers import AutoModelForCausalLM, AutoTokenizer
91+
92+
fp32_model = AutoModelForCausalLM.from_pretrained(
93+
"facebook/opt-125m",
94+
device_map="auto",
95+
)
96+
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True)
97+
output_dir = "./saved_inc"
98+
99+
# quantization configuration
100+
quant_config = AutoRoundConfig(
101+
tokenizer=tokenizer,
102+
nsamples=32,
103+
seqlen=32,
104+
iters=20,
105+
scheme="MXFP4", # MXFP4, MXFP8
106+
export_format="auto_round",
107+
output_dir=output_dir, # default is "temp_auto_round"
108+
)
109+
110+
# quantize the model and save to output_dir
111+
model = prepare(model=fp32_model, quant_config=quant_config)
112+
model = convert(model)
113+
114+
# loading
115+
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto", device_map="auto")
116+
117+
# inference
118+
text = "There is a girl who likes adventure,"
119+
inputs = tokenizer(text, return_tensors="pt").to(model.device)
120+
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))
94121
```
95122

96123
## Examples
97124

98-
- PyTorch [huggingface models](/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/mx_quant)
125+
- PyTorch [LLM/VLM models](/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4)
99126

100127

101128
## Reference

0 commit comments

Comments
 (0)