Commit 7e5d907

Update TorchAO README inference section before PTC (#3206)
Summary: att
Test Plan: visual inspection
1 parent d17d446 commit 7e5d907

File tree: 4 files changed (+57 −98 lines)

- README.md
- docs/source/api_ref_quantization.rst
- docs/source/quick_start.rst
- docs/source/serving.rst
README.md

Lines changed: 38 additions & 46 deletions
@@ -26,6 +26,7 @@

 - [Oct 20] MXFP8 MoE training prototype achieved **~1.45x speedup** for MoE layer in Llama4 Scout, and **~1.25x** speedup for MoE layer in DeepSeekV3 671b - with comparable numerics to bfloat16! Check out the [docs](./torchao/prototype/moe_training/) to try it out.
 - [Sept 25] MXFP8 training achieved [1.28x speedup on Crusoe B200 cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) with virtually identical loss curve to bfloat16!
+- [Sept 19] [TorchAO Quantized Model and Quantization Recipes Now Available on Huggingface Hub](https://pytorch.org/blog/torchao-quantized-models-and-quantization-recipes-now-available-on-huggingface-hub/)!
 - [Jun 25] Our [TorchAO paper](https://openreview.net/attachment?id=HpqH0JakHf&name=pdf) was accepted to CodeML @ ICML 2025!
 - [May 25] QAT is now integrated into [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) for fine-tuning ([docs](https://docs.axolotl.ai/docs/qat.html))!
 - [Apr 25] Float8 rowwise training yielded [1.34-1.43x training speedup](https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/) at 2k H100 GPU scale
@@ -59,13 +60,6 @@ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works

 Check out our [docs](https://docs.pytorch.org/ao/main/) for more details!

-From the team that brought you the fast series:
-* 9.5x inference speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
-* 10x inference speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
-* 3x inference speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3) (new: [flux-fast](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/))
-* 2.7x inference speedup for FAIR’s Seamless M4T-v2 model with [seamlessv2-fast](https://pytorch.org/blog/accelerating-generative-ai-4/)
-
-
 ## 🚀 Quick Start

 First, install TorchAO. We recommend installing the latest stable version:
@@ -76,20 +70,9 @@ pip install torchao
 Quantize your model weights to int4!
 ```python
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1))
-```
-Compared to a `torch.compiled` bf16 baseline, your quantized model should be significantly smaller and faster on a single A100 GPU:
-```bash
-int4 model size: 1.25 MB
-bfloat16 model size: 4.00 MB
-compression ratio: 3.2
-
-bf16 mean time: 30.393 ms
-int4 mean time: 4.410 ms
-speedup: 6.9x
+quantize_(model, Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))
 ```
-See our [quick start guide](https://docs.pytorch.org/ao/stable/quick_start.html) for more details. Alternatively, try quantizing your favorite model using our [HuggingFace space](https://huggingface.co/spaces/pytorch/torchao-my-repo)!
-
+See our [quick start guide](https://docs.pytorch.org/ao/stable/quick_start.html) for more details.
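For context on the new arguments, here is a minimal end-to-end sketch of how a call like this is typically exercised; the toy two-layer model, the tensor shapes, and the CUDA/bfloat16 setup are illustrative assumptions, not part of the README:

```python
# Minimal sketch (toy model and shapes assumed): int4 weight-only quantization end to end.
import torch
from torch import nn
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# A toy network standing in for a real model; only nn.Linear layers are quantized.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(torch.bfloat16).cuda()

# In-place weight quantization with the config shown above.
quantize_(
    model,
    Int4WeightOnlyConfig(
        group_size=32,
        int4_packing_format="tile_packed_to_4d",
        int4_choose_qparams_algorithm="hqq",
    ),
)

# Compile and run inference as usual.
model = torch.compile(model)
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)
```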

 ## 🛠 Installation

@@ -103,16 +86,18 @@ pip install torchao

 ```
 # Nightly
-pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
+pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128

 # Different CUDA versions
 pip install torchao --index-url https://download.pytorch.org/whl/cu126 # CUDA 12.6
+pip install torchao --index-url https://download.pytorch.org/whl/cu129 # CUDA 12.9
 pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only

 # For developers
 USE_CUDA=1 python setup.py develop
 USE_CPP=0 python setup.py develop
 ```
+
 </details>

 Please see the [torchao compatibility table](https://github.com/pytorch/ao/issues/2919) for version requirements for dependencies.
@@ -123,57 +108,64 @@ TorchAO is integrated into some of the leading open-source libraries including:

 * HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
 * HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
+* vLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html)
+* Integration with [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai) for SOTA kernels on server GPUs
+* Integration with [ExecuTorch](https://github.com/pytorch/executorch/) for edge device deployment
+* Axolotl for [QAT](https://docs.axolotl.ai/docs/qat.html) and [PTQ](https://docs.axolotl.ai/docs/quantize.html)
+* TorchTitan for [float8 pre-training](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md)
 * HuggingFace PEFT for LoRA using TorchAO as their [quantization backend](https://huggingface.co/docs/peft/en/developer_guides/quantization#torchao-pytorch-architecture-optimization)
-* Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
 * TorchTune for our NF4 [QLoRA](https://docs.pytorch.org/torchtune/main/tutorials/qlora_finetune.html), [QAT](https://docs.pytorch.org/torchtune/main/recipes/qat_distributed.html), and [float8 quantized fine-tuning](https://github.com/pytorch/torchtune/pull/2546) recipes
-* TorchTitan for [float8 pre-training](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md)
-* VLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html)
-* SGLang for LLM serving: [usage](https://docs.sglang.ai/backend/server_arguments.html#server-arguments) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).
-* Axolotl for [QAT](https://docs.axolotl.ai/docs/qat.html) and [PTQ](https://docs.axolotl.ai/docs/quantize.html)
-
+* SGLang for LLM serving: [usage](https://docs.sglang.ai/advanced_features/quantization.html#online-quantization)

 ## 🔎 Inference

 TorchAO delivers substantial performance gains with minimal code changes:

-- **Int4 weight-only**: [1.89x throughput with 58.1% less memory](torchao/quantization/README.md) on Llama-3-8B
-- **Float8 dynamic quantization**: [1.54x and 1.27x speedup on Flux.1-Dev* and CogVideoX-5b respectively](https://github.com/sayakpaul/diffusers-torchao) on H100 with preserved quality
+- **Int4 weight-only**: [1.73x speedup with 65% less memory](https://huggingface.co/pytorch/gemma-3-12b-it-INT4) for Gemma3-12b-it on H100 with slight impact on accuracy
+- **Float8 dynamic quantization**: [1.5-1.6x speedup on gemma-3-27b-it](https://huggingface.co/pytorch/gemma-3-27b-it-FP8/blob/main/README.md#results-h100-machine) and [1.54x and 1.27x speedup on Flux.1-Dev* and CogVideoX-5b respectively](https://github.com/sayakpaul/diffusers-torchao) on H100 with preserved quality
+- **Int8 activation quantization and int4 weight quantization**: Quantized Qwen3-4B running at 14.8 tokens/s with 3379 MB memory usage on iPhone 15 Pro through [ExecuTorch](https://huggingface.co/pytorch/Qwen3-4B-INT8-INT4#running-in-a-mobile-app)
 - **Int4 + 2:4 Sparsity**: [2.37x throughput with 67.7% memory reduction](torchao/sparsity/README.md) on Llama-3-8B

-Quantize any model with `nn.Linear` layers in just one line (Option 1), or load the quantized model directly from HuggingFace using our integration with HuggingFace transformers (Option 2):
-
-#### Option 1: Direct TorchAO API
-
-```python
-from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig
-quantize_(model, Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1))
-```
-
-#### Option 2: HuggingFace Integration
-
+The following is our recommended flow for quantization and deployment:
 ```python
 from transformers import TorchAoConfig, AutoModelForCausalLM
-from torchao.quantization.quant_api import Int4WeightOnlyConfig
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

 # Create quantization configuration
-quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1))
+quantization_config = TorchAoConfig(quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

 # Load and automatically quantize
 quantized_model = AutoModelForCausalLM.from_pretrained(
-    "microsoft/Phi-4-mini-instruct",
+    "Qwen/Qwen3-32B",
     dtype="auto",
     device_map="auto",
     quantization_config=quantization_config
 )
 ```
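As a usage illustration for the checkpoint loaded above, a short generation sketch; the tokenizer pairing, prompt, and generation settings here are assumptions rather than part of the README:

```python
# Illustrative usage of the quantized model loaded above (assumed Qwen3-32B tokenizer pairing).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B")
messages = [{"role": "user", "content": "Give me a short introduction to large language models."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(quantized_model.device)

# Generate and print only the newly produced tokens.
output = quantized_model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```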

-#### Deploy quantized models in vLLM with one command:
+An alternative quantization API, for cases where the flow above doesn't work, is the `quantize_` API described in the [quick start guide](https://docs.pytorch.org/ao/main/quick_start.html).
+
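As a sketch of that alternative path (the `model` here is assumed to be any already-loaded `nn.Module`, for example a transformers model):

```python
# Hedged sketch: apply the same float8 dynamic quantization directly with quantize_,
# instead of going through TorchAoConfig. `model` is an assumed, already-loaded nn.Module.
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```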
+Serving with vLLM on a 1xH100 machine:
+```shell
+# Server
+VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Qwen3-32B-FP8 --tokenizer Qwen/Qwen3-32B -O3
+```

 ```shell
-vllm serve pytorch/Phi-4-mini-instruct-int4wo-hqq --tokenizer microsoft/Phi-4-mini-instruct -O3
+# Client
+curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
+  "model": "pytorch/Qwen3-32B-FP8",
+  "messages": [
+    {"role": "user", "content": "Give me a short introduction to large language models."}
+  ],
+  "temperature": 0.6,
+  "top_p": 0.95,
+  "top_k": 20,
+  "max_tokens": 32768
+}'
 ```

-With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup** on A100 GPUs while maintaining model quality. For more detail, see this [step-by-step quantization guide](https://huggingface.co/pytorch/Phi-4-mini-instruct-int4wo-hqq#quantization-recipe). We also release some pre-quantized models [here](https://huggingface.co/pytorch).
+We also support deployment to edge devices through ExecuTorch; for more detail, see the [quantization and serving guide](https://docs.pytorch.org/ao/main/serving.html). We also release pre-quantized models [here](https://huggingface.co/pytorch).

 ## 🚅 Training

docs/source/api_ref_quantization.rst

Lines changed: 0 additions & 20 deletions
@@ -14,7 +14,6 @@ Main Quantization APIs
 :nosignatures:

 quantize_
-autoquant

 Inference APIs for quantize\_
 -------------------------------
@@ -27,13 +26,9 @@ Inference APIs for quantize\_
 Float8DynamicActivationInt4WeightConfig
 Float8DynamicActivationFloat8WeightConfig
 Float8WeightOnlyConfig
-Float8StaticActivationFloat8WeightConfig
 Int8DynamicActivationInt4WeightConfig
-GemliteUIntXWeightOnlyConfig
 Int8WeightOnlyConfig
 Int8DynamicActivationInt8WeightConfig
-UIntXWeightOnlyConfig
-FPXWeightOnlyConfig

 .. currentmodule:: torchao.quantization

@@ -51,19 +46,4 @@ Quantization Primitives
 safe_int_mm
 int_scaled_matmul
 MappingType
-ZeroPointDomain
 TorchAODType
-
-..
-TODO: delete these?
-
-Other
------
-
-.. autosummary::
-:toctree: generated/
-:nosignatures:
-
-to_linear_activation_quantized
-swap_linear_with_smooth_fq_linear
-smooth_fq_linear_to_inference

docs/source/quick_start.rst

Lines changed: 2 additions & 15 deletions
@@ -2,20 +2,8 @@ Quick Start Guide
 -----------------

 In this quick start guide, we will explore how to perform basic quantization using torchao.
-First, install the latest stable torchao release::
-
-    pip install torchao
-
-If you prefer to use the nightly release, you can install torchao using the following
-command instead::
-
-    pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121
-
-torchao is compatible with the latest 3 major versions of PyTorch, which you will also
-need to install (`detailed instructions <https://pytorch.org/get-started/locally/>`__)::
-
-    pip install torch

+Follow the `torchao installation and compatibility guide <https://github.com/pytorch/ao#-installation>`__ to install torchao and a compatible PyTorch.

 First Quantization Example
 ==========================
@@ -55,9 +43,8 @@ for efficient mixed dtype matrix multiplication:

 .. code:: py

-    # torch 2.4+ only
     from torchao.quantization import Int4WeightOnlyConfig, quantize_
-    quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1))
+    quantize_(model, Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))

 The quantized model is now ready to use! Note that the quantization
 logic is inserted through tensor subclasses, so there is no change
docs/source/serving.rst

Lines changed: 17 additions & 17 deletions
@@ -15,7 +15,7 @@ Post-training Quantization with HuggingFace
 -------------------------------------------

 HuggingFace Transformers provides seamless integration with torchao quantization. The ``TorchAoConfig`` automatically applies torchao's optimized quantization algorithms during model loading.
-Please check out our `HF Integration Docs <torchao_hf_integration.html>`_ for examples on how to use quantization and sparsity in Transformers and Diffusers.
+Please check out our `HF Integration Docs <torchao_hf_integration.html>`_ for examples of how to use quantization and sparsity in Transformers and Diffusers, and the `TorchAOConfig Reference <api_ref_quantization.html#inference-apis-for-quantize>`_ for all available torchao configs.

 Serving and Inference
 --------------------
@@ -29,19 +29,19 @@ First, install vLLM with torchao support:

 .. code-block:: bash

-    pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
-    pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
+    pip install vllm --pre --extra-index-url https://download.pytorch.org/whl/nightly/vllm/
+    pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128

 To serve in vLLM, we're using the model we quantized and pushed to the Hugging Face hub in the previous step, :ref:`Post-training Quantization with HuggingFace`.

 .. code-block:: bash

     # Server
-    vllm serve pytorch/Phi-4-mini-instruct-float8dq --tokenizer microsoft/Phi-4-mini-instruct -O3
+    vllm serve pytorch/Phi-4-mini-instruct-FP8 --tokenizer microsoft/Phi-4-mini-instruct -O3

     # Client
     curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
-      "model": "pytorch/Phi-4-mini-instruct-float8dq",
+      "model": "pytorch/Phi-4-mini-instruct-FP8",
       "messages": [
         {"role": "user", "content": "Give me a short introduction to large language models."}
       ],
@@ -271,8 +271,8 @@ Evaluate quantized models using lm-evaluation-harness:
     # Evaluate baseline model
     lm_eval --model hf --model_args pretrained=microsoft/Phi-4-mini-instruct --tasks hellaswag --device cuda:0 --batch_size 8

-    # Evaluate torchao-quantized model (float8dq)
-    lm_eval --model hf --model_args pretrained=pytorch/Phi-4-mini-instruct-float8dq --tasks hellaswag --device cuda:0 --batch_size 8
+    # Evaluate torchao-quantized model (FP8)
+    lm_eval --model hf --model_args pretrained=pytorch/Phi-4-mini-instruct-FP8 --tasks hellaswag --device cuda:0 --batch_size 8

 Memory Benchmarking
 ^^^^^^^^^^^^^^^^^
@@ -283,8 +283,8 @@ For Phi-4-mini-instruct, when quantized with float8 dynamic quant, we can reduce
     import torch
     from transformers import AutoModelForCausalLM, AutoTokenizer

-    # use "microsoft/Phi-4-mini-instruct" or "pytorch/Phi-4-mini-instruct-float8dq"
-    model_id = "pytorch/Phi-4-mini-instruct-float8dq"
+    # use "microsoft/Phi-4-mini-instruct" or "pytorch/Phi-4-mini-instruct-FP8"
+    model_id = "pytorch/Phi-4-mini-instruct-FP8"
     quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained(model_id)

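The hunk above only shows the model-loading portion of the memory benchmark; as an illustrative continuation (not the file's actual code, and assuming a CUDA device), peak memory after one generation could be read like this:

```python
# Illustrative continuation (assumed, not taken from serving.rst): run one generation
# with the quantized model loaded above and report the CUDA peak-memory counter.
torch.cuda.reset_peak_memory_stats()

prompt = "Give me a short introduction to large language models."
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device)

with torch.no_grad():
    quantized_model.generate(**inputs, max_new_tokens=64)

print(f"Peak Memory Usage: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB")
```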
@@ -328,7 +328,7 @@ Output:
     Peak Memory Usage: 5.70 GB

 +-------------------+---------------------+------------------------------+
-| Benchmark         | Phi-4 mini-instruct | Phi-4-mini-instruct-float8dq |
+| Benchmark         | Phi-4 mini-instruct | Phi-4-mini-instruct-FP8      |
 +===================+=====================+==============================+
 | Peak Memory (GB)  | 8.91                | 5.70 (36% reduction)         |
 +-------------------+---------------------+------------------------------+
@@ -342,10 +342,10 @@ Latency Benchmarking
 .. code-block:: bash

     # baseline
-    python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model microsoft/Phi-4-mini-instruct --batch-size 1
+    vllm bench latency --input-len 256 --output-len 256 --model microsoft/Phi-4-mini-instruct --batch-size 1

-    # float8dq
-    VLLM_DISABLE_COMPILE_CACHE=1 python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model pytorch/Phi-4-mini-instruct-float8dq --batch-size 1
+    # FP8
+    VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency --input-len 256 --output-len 256 --model pytorch/Phi-4-mini-instruct-FP8 --batch-size 1

 Serving Benchmarking
 """""""""""""""""""""
@@ -372,13 +372,13 @@ We benchmarked the throughput in a serving environment.
     # Server:
     vllm serve microsoft/Phi-4-mini-instruct --tokenizer microsoft/Phi-4-mini-instruct -O3
     # Client:
-    python benchmarks/benchmark_serving.py --backend vllm --dataset-name sharegpt --tokenizer microsoft/Phi-4-mini-instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --model microsoft/Phi-4-mini-instruct --num-prompts 1
+    vllm bench serve --backend vllm --dataset-name sharegpt --tokenizer microsoft/Phi-4-mini-instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --model microsoft/Phi-4-mini-instruct --num-prompts 1

-    # For float8dq
+    # For FP8
     # Server:
-    VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Phi-4-mini-instruct-float8dq --tokenizer microsoft/Phi-4-mini-instruct -O3
+    VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Phi-4-mini-instruct-FP8 --tokenizer microsoft/Phi-4-mini-instruct -O3
     # Client:
-    python benchmarks/benchmark_serving.py --backend vllm --dataset-name sharegpt --tokenizer microsoft/Phi-4-mini-instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --model pytorch/Phi-4-mini-instruct-float8dq --num-prompts 1
+    vllm bench serve --backend vllm --dataset-name sharegpt --tokenizer microsoft/Phi-4-mini-instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --model pytorch/Phi-4-mini-instruct-FP8 --num-prompts 1

 Results (H100 machine)
 """""""""""""""""""""
