
Commit 2f903f8

[mxfp8 moe training][BE] add docs showing equivalent convergence to bf16 at scale (#3312)
1 parent 86af458 commit 2f903f8

File tree: 2 files changed, +23 -1 lines changed


docs/static/mxfp8_with_loss.png (45.9 KB)

torchao/prototype/moe_training/README.md

Lines changed: 23 additions & 1 deletion
@@ -6,14 +6,36 @@ This prototype provides:
- Using MXFP8 on a B200 GPU, this provides:
    - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
    - **~1.19x - 1.6x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes
    - These benchmarks use `seq_len=8192` and `local_batch_size=16` (so `total_M = 8192 * 16 = 131,072`). We recommend using a large `total_M` dim to maximize speedup. See [benchmarks](#microbenchmarks) for more details.

2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding these flags to your training command: `--model.converters="quantize.grouped_mm.mx" --quantize.grouped_mm.mx.fqns="experts"`

3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
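
For orientation, here is a minimal sketch of what that conversion can look like, loosely based on the example linked above. The `MoETrainingConfig` import path, its constructor arguments, and the filter signature are assumptions that may lag the current API, so treat the linked end-to-end example as the source of truth; in TorchTitan the same conversion is driven by the `--model.converters="quantize.grouped_mm.mx"` flags from item 2 rather than by hand-written code.

```python
import torch.nn as nn

# Assumed import paths -- consult the linked end-to-end example for the
# current locations of these symbols.
from torchao.quantization.quant_api import quantize_
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig


def moe_filter(module: nn.Module, fqn: str) -> bool:
    # Only convert the routed-expert modules whose forward calls torch._grouped_mm.
    return "experts" in fqn


def convert_moe_to_quantized_grouped_mm(model: nn.Module) -> nn.Module:
    """Swap expert grouped GEMMs to torchao's quantized path, in place."""
    # After this call, matching modules dispatch their grouped GEMMs to
    # torchao's _quantize_then_scaled_grouped_mm instead of torch._grouped_mm.
    quantize_(model, MoETrainingConfig(), filter_fn=moe_filter)
    return model
```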

## Equivalent convergence to bfloat16 training baseline

Training runs on a 64-node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training converges equivalently to the bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/).

<img alt="MXFP8 vs bfloat16 loss curves" src="../../../docs/static/mxfp8_with_loss.png" />

Training and model configurations for this run:
- Model: Llama4 Scout
- Dataset: C4
- Sequence length: 8192
- Local batch size: 1
- Learning rate: 1e-4
- LR scheduler warmup steps: 2000
- Parallelisms (64 nodes of 4 devices each = 256 chips):
    - FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts)
    - EP=16 (on routed experts)
- Activation checkpointing mode: `none` (ideally this would use selective per-op AC, but a bug at the time prevented us from using it).
- `torch.compile` enabled
- `mxfp8` applied to routed experts computation (grouped GEMMs)
- `mxfp8` applied to all linear layers except `output`, `router.gate`, `attention.wk`, and `attention.wv` (Wk and Wv are too small to benefit from mxfp8); see the filter sketch below
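
To make that exclusion list concrete, the module selection can be pictured as in the sketch below. This is illustrative only: the `quantize_` call and the mxfp8 linear config are left as a commented placeholder (the config class name is an assumption), and in this run the selection was configured through TorchTitan's converter options rather than custom code.

```python
import torch.nn as nn

# FQN fragments excluded from mxfp8 in this run, per the list above.
SKIP_FQNS = ("output", "router.gate", "attention.wk", "attention.wv")


def mxfp8_linear_filter(module: nn.Module, fqn: str) -> bool:
    """Apply mxfp8 only to nn.Linear layers that are not on the skip list."""
    if not isinstance(module, nn.Linear):
        return False
    return not any(skip in fqn for skip in SKIP_FQNS)


# Hypothetical usage (config class name is an assumption):
# quantize_(model, MXLinearConfig(), filter_fn=mxfp8_linear_filter)
```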

## Table of Contents

- [Examples](#examples)
