Commit d208768

mx_formats: Update README.md (#3210)
1 parent 94dee9c commit d208768

1 file changed (+38 −15)

torchao/prototype/mx_formats/README.md

Lines changed: 38 additions & 15 deletions
@@ -1,16 +1,42 @@
 # MX training and inference with native PyTorch
 
-This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
+e2e training and inference with the mxfp8, mxfp4, and nvfp4 formats from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
+in native PyTorch.
+
+> :warning: We are currently in prototype. Use nightly versions of PyTorch and torchao (or build from source) for best results.
 
 ## Overall status
 
-| workflow | emulation | performance | accuracy |
-| --- | --- | --- | --- |
-| training with mxfp8 | ✅ | ✅ | ✅ |
-| inference with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 |
+### mxfp8
+
+| workflow | emulation | performance | accuracy | API polish |
+| --- | --- | --- | --- | --- |
+| training for `torch.nn.Linear` | ✅ | 🟡 / 🟢 | 🟢 | 🟡 |
+| inference for `torch.nn.Linear` | ✅ | 🟡 / 🟢 | 🟢 | 🟡 |
+
+### nvfp4
+
+| workflow | emulation | performance | accuracy | API polish |
+| --- | --- | --- | --- | --- |
+| training for `torch.nn.Linear` | ✅ | 🔴 | 🟡 | 🟡 |
+| QAT for `torch.nn.Linear` | ✅ | n/a | 🟢 | 🟡 |
+| inference for `torch.nn.Linear` | ✅ | 🟡 / 🟢 | 🟢 | 🟡 |
+
+### mxfp4
 
-ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>
+| workflow | emulation | performance | accuracy | API polish |
+| --- | --- | --- | --- | --- |
+| training for `torch.nn.Linear` | ✅ | 🔴 | 🟡 | 🟡 |
+| QAT for `torch.nn.Linear` | planned | n/a | planned | planned |
+| inference for `torch.nn.Linear` | ✅ | 🔴 | 🟢 | 🟡 |
+
+### planned improvements
+
+* mxfp8 support for grouped_gemm and all2all for MoE training (see https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training)
+* mxfp8, nvfp4, and mxfp4 performance optimizations for inference
+* polish the nvfp4 QAT recipe, and enable mxfp4 QAT
+* blocked formats for faster training
+* stochastic rounding and Hadamard transforms for improved fp4 training numerics
 
 ## Training e2e benchmarks on NVIDIA B200
 
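All of the formats in the status tables above share the same basic structure from the OCP MX spec: blocks of 32 elements that share a single power-of-two (e8m0-style) scale. Below is a minimal pure-Python sketch of that shared-scale computation for mxfp8, for illustration only; the `F8E4M3_MAX` constant and the exponent arithmetic follow the common convention and are not the torchao kernels.

```python
import math

BLOCK_SIZE = 32     # MX OCP spec: 32 elements share one scale
F8E4M3_MAX = 448.0  # max representable magnitude of an fp8 e4m3 element

def mx_block_scale(block):
    """Compute a shared power-of-two (e8m0-style) scale for one block,
    chosen so the largest element lands near the fp8 e4m3 max."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 1.0
    # power-of-two scale: 2^(floor(log2(amax)) - floor(log2(F8E4M3_MAX)))
    exp = math.floor(math.log2(amax)) - math.floor(math.log2(F8E4M3_MAX))
    return 2.0 ** exp

block = [0.01 * i for i in range(BLOCK_SIZE)]
scale = mx_block_scale(block)
scaled = [x / scale for x in block]  # these values would then be cast to fp8
assert max(abs(x) for x in scaled) <= 2 * F8E4M3_MAX
```

A real implementation would also clamp the scaled values into the fp8 representable range and store the scale itself as an e8m0 exponent byte.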
@@ -42,6 +68,8 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
 
 ## MX training
 
+Below is a toy training loop. For an example of a real training loop, see our torchtitan integration here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/mx.py
+
 ```python
 import torch
 from torchao.quantization import quantize_
@@ -150,7 +178,7 @@ x_hp = x_mx.to_dtype(torch.float)
 ## mxfp8 gemm
 
 On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
-We observe a speedup of **2x to 3x** vs the bf16 baseline on common shapes. To reproduce this
+We observe a speedup of **up to ~2x** vs the bf16 baseline on common shapes. To reproduce this
 on supported hardware, you can run the following command:
 
 ```bash
@@ -160,7 +188,7 @@ on supported hardware, you can run the following command:
 
 ## to_mx cast across dim0 and dim1
 
-On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
+On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 6.3 TB/s** for the dim0 cast (with torch.compile),
 and **up to 3.9 TB/s** for the dim1 cast (with a triton kernel). We are actively working on improving
 the performance of this cast ([details](https://github.com/pytorch/ao/issues/1768)).
 
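To make the dim0 vs dim1 distinction concrete: MX scales are computed over 1x32 blocks, so a dim0 cast reads contiguous elements while a dim1 cast reads strided (column) elements, which is why the two casts need different kernels and reach different bandwidths. The toy pure-Python sketch below shows just the per-block amax gathering; the helper names are hypothetical, not the torchao/triton kernels.

```python
# Toy illustration of why the dim1 cast is harder than the dim0 cast:
# scaling "across dim1" means gathering strided (column) elements.
ROWS, COLS, BLOCK = 32, 64, 32

def amax_blocks_dim0(matrix):
    """amax per 32-element block along each row (contiguous reads)."""
    return [
        [max(abs(x) for x in row[j:j + BLOCK]) for j in range(0, COLS, BLOCK)]
        for row in matrix
    ]

def amax_blocks_dim1(matrix):
    """amax per 32-element block along each column (strided reads)."""
    return [
        [max(abs(matrix[i][c]) for i in range(k, min(k + BLOCK, ROWS)))
         for k in range(0, ROWS, BLOCK)]
        for c in range(COLS)
    ]

m = [[float(i * COLS + j) for j in range(COLS)] for i in range(ROWS)]
d0 = amax_blocks_dim0(m)  # ROWS x (COLS // BLOCK) block maxes
d1 = amax_blocks_dim1(m)  # COLS x (ROWS // BLOCK) block maxes
assert len(d0) == ROWS and len(d0[0]) == COLS // BLOCK
assert len(d1) == COLS and len(d1[0]) == ROWS // BLOCK
```

The fast kernels avoid the strided reads by tiling the matrix so that both the dim0 and dim1 blocks of a tile stay in fast memory.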
@@ -176,16 +204,11 @@ To reproduce this on supported hardware, you can run the following command:
 // example output: https://gist.github.com/vkuzo/7ac5fce44c9b90bfb9eae2a07b721cda
 ```
 
-## performance tracker
-
-Please see our [performance tracker](https://github.com/pytorch/ao/issues/1768) for the latest on MX training and inference performance!
-
 # accuracy
 
 ## training
 
-* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (code not in this repo)
-* we match bitwise to other implementations of the OCP MX spec (code not in this repo), with a couple of edge cases left to resolve
+* LLaMa 3 8B pretraining on 4 GPUs for 500 iterations shows that loss convergence is not meaningfully degraded (via torchtitan)
 
 ## inference
 