Commit ad533f5

Update QAT README before PTC (#3214)
Update QAT README
1 parent 7e5d907 commit ad533f5

2 files changed: +117 -39 lines changed
README.md

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@
 
 ### PyTorch-Native Training-to-Serving Model Optimization
 - Pre-train Llama-3.1-70B **1.5x faster** with float8 training
-- Recover **77% of quantized perplexity degradation** on Llama-3.2-3B with QAT
+- Recover **67% of quantized accuracy degradation** on Gemma3-4B with QAT
 - Quantize Llama-3-8B to int4 for **1.89x faster** inference with **58% less memory**
 
 <div align="center">
@@ -106,6 +106,7 @@ Please see the [torchao compability table](https://github.com/pytorch/ao/issues/
 
 TorchAO is integrated into some of the leading open-source libraries including:
 
+* Unsloth for QAT, blog post coming soon!
 * HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
 * HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
 * vLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html)

torchao/quantization/qat/README.md

Lines changed: 115 additions & 38 deletions
@@ -142,7 +142,8 @@ quantize_(m, qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding
 ```
 
 
-### Quantizer API (legacy)
+<details>
+<summary><h3>Quantizer API (legacy)</h3></summary>
 
 Alternatively, torchao provides a few hardcoded quantization settings through
 the following Quantizers, but these may be removed soon:
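For context on the section being wrapped above, here is a minimal, self-contained sketch of the legacy prepare → fine-tune → convert flow whose tail appears as context in the next hunk. `Int8DynActInt4WeightQATQuantizer` is one of torchao's legacy QAT Quantizers; the toy model, `groupsize=32`, and the single optimizer step standing in for `train_loop(model)` are illustrative assumptions, not part of the diff.

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy stand-in for a real model; any nn.Module containing nn.Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# Legacy API: hardcoded int8 dynamic activations + int4 grouped weights.
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# prepare() swaps in linears that fake-quantize weights and activations,
# so the fine-tuning loss already reflects quantization error.
model = qat_quantizer.prepare(model)

# Stand-in for train_loop(model): one optimizer step on random data.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss = model(torch.randn(4, 512)).sum()
loss.backward()
optimizer.step()

# convert() replaces the fake-quantized linears with actually quantized ones.
model = qat_quantizer.convert(model)
```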
@@ -191,8 +192,51 @@ model = qat_quantizer.prepare(model)
 train_loop(model)
 model = qat_quantizer.convert(model)
 ```
+</details>
 
-## torchtune integration
+## Axolotl integration
+
+[Axolotl](https://github.com/axolotl-ai-cloud) uses TorchAO to support quantization-aware fine-tuning. You can use the following commands to fine-tune, and then quantize a Llama-3.2-3B model:
+
+```bash
+axolotl train examples/llama-3/3b-qat-fsdp2.yaml
+# once training is complete, perform the quantization step
+axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
+# you should now have a quantized model saved in ./outputs/qat_out/quantized
+```
+
+Please see the [QAT documentation](https://docs.axolotl.ai/docs/qat.html) in axolotl for more details.
+
+
+## Unsloth integration
+
+[Unsloth](https://github.com/unslothai/unsloth) also leverages TorchAO for quantization-aware fine-tuning. Unsloth's QAT support can be used with both full and LoRA fine-tuning. For example:
+
+```python
+import torch
+from unsloth import FastLanguageModel
+
+model, tokenizer = FastLanguageModel.from_pretrained(
+    "unsloth/Qwen3-4B-Instruct-2507",
+    max_seq_length = 2048,
+    dtype = torch.bfloat16,
+    load_in_4bit = False,
+    full_finetuning = False,
+)
+
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 16,
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+    lora_alpha = 16,
+    qat_scheme = "int4",
+)
+```
+
+For a full notebook example, see: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(14B)-Reasoning-Conversational.ipynb. A QAT-specific notebook is coming soon.
+
+
+<details>
+<summary><h2>torchtune integration (legacy)</h2></summary>
 
 torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune)
 to allow users to run quantized-aware fine-tuning as follows:
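Both integrations added in this hunk call into torchao QAT under the hood. As a point of reference, the sketch below shows what a direct torchao flow for int4 weight-only QAT at group size 128 (the setting evaluated in the next hunk) could look like. It is not part of the diff: `QATConfig` with `step="prepare"`/`"convert"` and `Int4WeightOnlyConfig` are assumptions about torchao's current config-based API, and the toy model plus single training step are purely illustrative.

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig

# Toy stand-in for a fine-tunable model; the int4 tinygemm path assumes a CUDA GPU.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
base_config = Int4WeightOnlyConfig(group_size=128)

# Prepare: insert fake-quantize ops so training sees int4 rounding error.
quantize_(model, QATConfig(base_config, step="prepare"))

# Fine-tune as usual (a single illustrative step shown here).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
optimizer.step()

# Convert: swap fake quantization for a real int4 weight-only quantized model.
quantize_(model, QATConfig(base_config, step="convert"))
```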
@@ -210,47 +254,80 @@ tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config ll
 ```
 
 For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
+</details>
 
-## Axolotl integration
+## Evaluation Results
 
-[Axolotl](https://github.com/axolotl-ai-cloud) uses torchao to support quantized-aware fine-tuning. You can use the following commands to fine-tune, and then quantize a Llama-3.2-3B model:
+Int4 weight-only QAT + LoRA using a group size of 128, fine-tuned using Unsloth.
+Both fine-tuning and evaluation were done on a single H100 GPU using the
+[mlabonne/FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
+dataset. Learning rate was 2e-5 and batch size was 64 with no gradient accumulation.
 
-```bash
-axolotl train examples/llama-3/3b-qat-fsdp2.yaml
-# once training is complete, perform the quantization step
-axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
-# you should now have a quantized model saved in ./outputs/qat_out/quatized
+```
+# gemma3-12b-it
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
++=============+=================+=================+============+=============+
+| wikitext | 9.1477 | 9.7745 | 9.5631 | 33.727% |
++-------------+-----------------+-----------------+------------+-------------+
+| bbh | 0.8079 | 0.7624 | 0.7831 | 45.495% |
++-------------+-----------------+-----------------+------------+-------------+
+
+# gemma3-4b-it
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
++=============+=================+=================+============+=============+
+| wikitext | 12.1155 | 13.247 | 12.797 | 39.770% |
++-------------+-----------------+-----------------+------------+-------------+
+| bbh | 0.7074 | 0.6415 | 0.6666 | 38.088% |
++-------------+-----------------+-----------------+------------+-------------+
+| gpqa | 0.3232 | 0.3081 | 0.3182 | 66.887% |
++-------------+-----------------+-----------------+------------+-------------+
+
+# Qwen3-4B-Instruct
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
++=============+=================+=================+============+=============+
+| mmlu-pro | 0.4909 | 0.4328 | 0.4524 | 33.735% |
++-------------+-----------------+-----------------+------------+-------------+
+
+# Llama3.2-3B
++-------------+-----------------+-----------------+------------+-------------+
+| Eval task | bf16 baseline | int4 baseline | int4 QAT | recovered |
++=============+=================+=================+============+=============+
+| wikitext | 12.1322 | 13.3459 | 12.8796 | 38.420% |
++-------------+-----------------+-----------------+------------+-------------+
+| bbh | 0.5483 | 0.4967 | 0.5174 | 40.116% |
++-------------+-----------------+-----------------+------------+-------------+
+| gpqa | 0.3333 | 0.2879 | 0.303 | 33.260% |
++-------------+-----------------+-----------------+------------+-------------+
+| mmlu-pro | 0.2771 | 0.2562 | 0.2629 | 32.057% |
++-------------+-----------------+-----------------+------------+-------------+
 ```
 
-Please see the [QAT documentation](https://docs.axolotl.ai/docs/qat.html) in axolotl for more details.
-
-## Evaluation Results
+NVFP4 QAT full fine-tuning, fine-tuned using Axolotl on 8x B200 GPUs on the
+[yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned)
+dataset. Learning rate was 2e-5 and batch size was 128 for `gemma3-12b-it`
+and 32 for `Qwen3-8B`.
 
-Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT
-integration described above. We fine-tune [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) (en subset)
-for 5000 steps using a group size of 256 for the weights. Note that extensive
-hyperparameter tuning may further improve these results.
-
-Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | ------ | ------ | ------ | ------ | ------ |
-| No quantization | 57.86% | 76.60% | 8.905 | 1.505 | 0.590 |
-| PTQ | 51.74% | 70.66% | 11.878 | 1.588 | 0.668 |
-| QAT (quantized) | 57.25% | 76.51% | 9.859 | 1.534 | 0.617 |
-| PTQ degradation | -6.11% | -5.94% | +2.973 | +0.083 | +0.078 |
-| QAT degradation | -0.61% | -0.21% | +0.947 | +0.029 | +0.027 |
-
-Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the
-quantized path uses the more efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097).
-
-| | hellaswag<br>(acc) | hellaswag<br>(acc_norm) | wikitext<br>(word_perplexity) | wikitext<br>(byte_perplexity) | wikitext<br>(bits_per_byte) |
-| ---------------- | -------- | ------- | ------ | ------ | ------ |
-| No quantization | 57.16% | 77.02% | 8.858 | 1.504 | 0.589 |
-| PTQ | 55.06% | 74.24% | 10.311 | 1.547 | 0.630 |
-| QAT (quantized) | 55.86% | 75.06% | 10.134 | 1.542 | 0.625 |
-| PTQ degradation | -2.10% | -2.78% | +1.453 | +0.043 | +0.041 |
-| QAT degradation | -1.30% | -1.96% | +1.276 | +0.038 | +0.036 |
+```
+# gemma3-12b-it
++-------------+-----------------+------------------+-------------+-------------+
+| Eval task | bf16 baseline | nvfp4 baseline | nvfp4 QAT | recovered |
++=============+=================+==================+=============+=============+
+| bbh | 0.7527 | 0.7068 | 0.7222 | 33.551% |
++-------------+-----------------+------------------+-------------+-------------+
+| mmlu-pro | 0.4074 | 0.3621 | 0.3702 | 17.881% |
++-------------+-----------------+------------------+-------------+-------------+
+
+# Qwen3-8B
++-------------+-----------------+------------------+-------------+-------------+
+| Eval task | bf16 baseline | nvfp4 baseline | nvfp4 QAT | recovered |
++=============+=================+==================+=============+=============+
+| bbh | 0.7771 | 0.7262 | 0.7397 | 26.523% |
++-------------+-----------------+------------------+-------------+-------------+
+| mmlu-pro | 0.4929 | 0.4519 | 0.4686 | 40.732% |
++-------------+-----------------+------------------+-------------+-------------+
+```
 
 For more details, please refer to [this blog post](https://pytorch.org/blog/quantization-aware-training).
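The "recovered" columns in the tables above are consistent with measuring the fraction of the bf16 → quantized-baseline degradation that QAT wins back. The small sketch below is not part of the diff (the helper name is ours); it reproduces the gemma3-12b-it wikitext and bbh entries from the first table.

```python
def recovered_fraction(bf16: float, quant_baseline: float, quant_qat: float) -> float:
    """Fraction of the quantization-induced degradation recovered by QAT.

    Works for both lower-is-better metrics (perplexity) and higher-is-better
    metrics (accuracy), since the sign of the gap cancels in the ratio.
    """
    gap_introduced = quant_baseline - bf16   # degradation from quantizing the bf16 model
    gap_closed = quant_baseline - quant_qat  # portion of that gap closed by QAT
    return gap_closed / gap_introduced

# gemma3-12b-it, wikitext perplexity: ~0.337, matching the reported 33.727%
print(recovered_fraction(bf16=9.1477, quant_baseline=9.7745, quant_qat=9.5631))

# gemma3-12b-it, bbh accuracy: ~0.455, matching the reported 45.495%
print(recovered_fraction(bf16=0.8079, quant_baseline=0.7624, quant_qat=0.7831))
```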
