# MX training and inference with native PyTorch

e2e training and inference with mxfp8, mxfp4, and nvfp4 formats from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch.

> :warning: We are currently in prototype. Use nightly versions of PyTorch and torchao (or build from source) for best results.

ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

What's next:

* mxfp8 support for grouped_gemm and all2all for MoE training (see https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training)
* mxfp8, nvfp4, and mxfp4 performance optimizations for inference
* polish the nvfp4 QAT recipe, and enable mxfp4 QAT
* blocked formats for faster training
* stochastic rounding and Hadamard transforms for improved fp4 training numerics

## Training e2e benchmarks on NVIDIA B200
## MX training
Below is a toy training loop. For a real training loop, see our [torchtitan integration](https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/mx.py).
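
A minimal sketch of such a loop, assuming the prototype's `MXLinearConfig` and `quantize_` APIs; since this workflow is in prototype, names, defaults, and hardware requirements may change, and the model shapes, dtype, and loss below are illustrative only:

```python
# Toy training loop sketch: swap nn.Linear layers to mxfp8, then train as usual.
# Assumes the torchao prototype APIs `MXLinearConfig` and `quantize_`; these are
# prototype interfaces and may change between releases.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(
    nn.Linear(256, 512, bias=False),
    nn.ReLU(),
    nn.Linear(512, 256, bias=False),
).cuda().to(torch.bfloat16)

# mxfp8 with the OCP-specified block size of 32; dims should be divisible by
# the block size. On pre-Blackwell GPUs the mx gemm may need to run emulated.
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-4)
for _ in range(10):
    x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)
    loss = m(x).pow(2).mean()  # dummy loss for illustration
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```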