We want to support NVFP4 training in torchao for both dense and MoE models, following the recipe described in this paper from NVIDIA.
Support for dense models (linears)
- Functionality
  - Random Hadamard transform for wgrad gemm inputs (16x16 block granularities)
  - Rounding modes ([RFC] NVFP4 Rounding Modes #3264); a reference sketch of both modes follows this list
    - Stochastic rounding (SR) for gradients
      - Triton wrapping inline PTX, e.g. `cvt.rs.satfinite.e2m1x4.f32 d, {a, b, e, f}, rbits;` (converts 4 fp32 values to 4 packed e2m1 values, applying `.rs` rounding): https://docs.nvidia.com/cuda/parallel-thread-execution/#rounding-modifiers
    - Round to nearest even (RTNE) for activations and weights (e.g. the `.rn` modifier)
      - Triton wrapping inline PTX with the `.rn` modifier, same as above
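For reference, below is a minimal, non-performant PyTorch sketch of the two rounding modes on the e2m1 value grid, only to pin down the semantics the Triton/PTX kernels should match. The function names are hypothetical, and `rand_like` stands in for the caller-provided `rbits` consumed by `cvt.rs`.

```python
import torch

# Positive representable e2m1 magnitudes (sign bit handled separately).
_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def round_nearest_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Round-to-nearest onto the e2m1 grid (tie-to-even is left to the .rn PTX path)."""
    grid = _E2M1_GRID.to(x.device)
    mag = x.float().abs().clamp(max=grid[-1])
    idx = (mag.unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return (torch.sign(x).float() * grid[idx]).to(x.dtype)

def stochastic_round_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Stochastic rounding onto the e2m1 grid: round up to the next grid value with
    probability proportional to the distance from the lower one. The .rs PTX path
    consumes caller-provided random bits instead of rand_like."""
    grid = _E2M1_GRID.to(x.device)
    mag = x.float().abs().clamp(max=grid[-1])
    hi_idx = torch.bucketize(mag, grid).clamp(max=grid.numel() - 1)
    lo_idx = (hi_idx - 1).clamp(min=0)
    lo, hi = grid[lo_idx], grid[hi_idx]
    p_up = (mag - lo) / (hi - lo).clamp(min=1e-12)
    rounded = torch.where(torch.rand_like(mag) < p_up, hi, lo)
    return (torch.sign(x).float() * rounded).to(x.dtype)
```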
  - Quantization (a reference sketch of 1x16 block scaling follows this list)
    - 1x16 quantization for LHS activations (support RTNE and SR, as it will be used for grads)
    - 16x1 quantization for RHS activations (support RTNE; if we do `wgrad = grad_out.t() @ input` we won't need SR for 16x1 scaling / RHS operand)
    - 16x16 quantization for weights, use RTNE
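As a reference for the quantization items above, here is a sketch of the two-level scaling NVFP4 uses for 1x16 blocks (an fp8 e4m3 scale per 16-element block combined with one fp32 per-tensor scale, so each block's amax maps to the e2m1 max of 6.0). The helper name and exact scale selection/clamping are illustrative assumptions, not the torchao kernel.

```python
import torch

F4_MAX = 6.0        # largest e2m1 magnitude
F8E4M3_MAX = 448.0  # largest e4m3 magnitude

def nvfp4_1x16_scales(x: torch.Tensor):
    """x: (M, K) with K % 16 == 0. Returns (per_tensor_scale, per_block_scales, x_scaled)."""
    M, K = x.shape
    blocks = x.reshape(M, K // 16, 16)
    block_amax = blocks.abs().amax(dim=-1)                       # (M, K // 16)
    # Per-tensor fp32 scale keeps the per-block scales within e4m3 range.
    per_tensor = (x.abs().amax() / (F4_MAX * F8E4M3_MAX)).to(torch.float32)
    # Per-block scale (stored as e4m3) maps each block's amax to F4_MAX.
    per_block = (block_amax / (F4_MAX * per_tensor)).to(torch.float8_e4m3fn)
    decode_scale = per_block.to(torch.float32) * per_tensor      # (M, K // 16)
    x_scaled = blocks / decode_scale.unsqueeze(-1).clamp(min=1e-12)
    # x_scaled would then be rounded (RTNE or SR) onto the e2m1 grid and packed.
    return per_tensor, per_block, x_scaled.reshape(M, K)
```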
  - 1x16 @ 16x16 scaled gemm (see the bf16 reference of all three gemms after this list)
    - Use for `output = input @ weight.t()`
    - Use for `dgrad = grad_output @ weight`
  - 1x16 @ 16x1 scaled gemm
    - Use for `wgrad = grad_output.t() @ input`
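For reference, the three gemms above in plain bf16, showing which operand carries which scaling granularity (shapes are arbitrary example values, no quantization here):

```python
import torch

M, K, N = 128, 256, 512
x = torch.randn(M, K, dtype=torch.bfloat16)   # input activations
w = torch.randn(N, K, dtype=torch.bfloat16)   # weight (nn.Linear layout)
g = torch.randn(M, N, dtype=torch.bfloat16)   # upstream grad

out = x @ w.t()     # forward:  1x16 (x)    @ 16x16 (w)  -> (M, N)
dgrad = g @ w       # backward: 1x16 (g)    @ 16x16 (w)  -> (M, K)
wgrad = g.t() @ x   # backward: 1x16 (g.t)  @ 16x1  (x)  -> (N, K)
```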
  - Autograd function implementing forward and backward with dynamic quant as described in the paper (skeleton sketched after this list):
    - Forward
      - 1x16 quant with RTNE on 2d input
      - 16x16 quant with RTNE on 2d weights
      - 1x16 @ 16x16 scaled mm for output
    - Backward
      - dgrad
        - 1x16 quant with SR on 2d upstream grad
        - 16x16 quant on 2d weights with RTNE
        - per group scale factor conversion to blocked swizzled format (2d grad)
        - per group scale factor conversion to blocked swizzled format (2d weight)
        - 1x16 @ 16x16 scaled mm for dgrad
      - wgrad
        - RHT on upstream grad (16x16 granularity)
        - RHT on input activations (16x16 granularity)
        - 1x16 quant with SR for 2d transposed upstream grad
        - 16x1 quant with RTNE for 2d input activations
        - 1x16 @ 16x1 scaled_mm for wgrad
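A skeleton of the dense autograd function, under the assumptions above. Every helper in the stub block is a hypothetical stand-in (fake-quant / identity / plain matmul) for the custom kernels this issue tracks; only the dataflow (which rounding mode and scaling granularity is applied to which operand of which gemm) is the point.

```python
import torch

# --- hypothetical stand-ins for the real custom ops (fake-quant / identity) ---
def nvfp4_quant_1x16(t, rounding):   return t      # 1x16 scaling, rounding in {"rtne", "sr"}
def nvfp4_quant_16x16(t, rounding):  return t      # 16x16 scaling
def nvfp4_quant_16x1(t, rounding):   return t      # 16x1 scaling
def rht_16x16(t):                    return t      # random Hadamard transform, 16x16 blocks
def nvfp4_scaled_mm(a, b):           return a @ b  # real op also takes blocked-swizzled scales

class NVFP4LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_2d, weight):
        ctx.save_for_backward(input_2d, weight)
        x_q = nvfp4_quant_1x16(input_2d, rounding="rtne")
        w_q = nvfp4_quant_16x16(weight, rounding="rtne")
        return nvfp4_scaled_mm(x_q, w_q.t())                  # output = input @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        input_2d, weight = ctx.saved_tensors
        # dgrad = grad_out @ weight  (1x16 SR grad @ 16x16 RTNE weight, swizzled scales)
        g_q = nvfp4_quant_1x16(grad_out, rounding="sr")
        w_q = nvfp4_quant_16x16(weight, rounding="rtne")
        dgrad = nvfp4_scaled_mm(g_q, w_q)
        # wgrad = grad_out.t() @ input  (RHT on both operands, 1x16 SR @ 16x1 RTNE)
        gt_q = nvfp4_quant_1x16(rht_16x16(grad_out).t().contiguous(), rounding="sr")
        x_q = nvfp4_quant_16x1(rht_16x16(input_2d), rounding="rtne")
        wgrad = nvfp4_scaled_mm(gt_q, x_q)
        return dgrad, wgrad
```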
  - NVFP4Linear wrapping autograd func (used for module swaps)
  - DTensor handling for TP support
  - Custom ops around all custom kernels for `torch.compile` composability
  - Tests for FSDP, TP
  - `quantize_` model conversion API performing module swap of nn.Linear to NVFP4Linear (wraps autograd func); a module-swap sketch follows this list
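A minimal sketch of the module-swap path, assuming the `NVFP4LinearFn` skeleton above. `NVFP4Linear` and `swap_linear_with_nvfp4` are hypothetical names for the module and conversion entry point tracked here; the real conversion would be exposed through torchao's `quantize_` API with an NVFP4 training config. Bias handling is omitted for brevity.

```python
import torch
import torch.nn as nn

class NVFP4Linear(nn.Linear):
    """nn.Linear whose forward dispatches to the NVFP4 autograd function."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape
        out = NVFP4LinearFn.apply(x.reshape(-1, orig_shape[-1]), self.weight)
        return out.reshape(*orig_shape[:-1], out.shape[-1])

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "NVFP4Linear":
        new = cls(mod.in_features, mod.out_features, bias=False)
        new.weight = mod.weight
        return new

def swap_linear_with_nvfp4(model: nn.Module) -> nn.Module:
    """Recursively replace (bias-free) nn.Linear modules with NVFP4Linear."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and child.bias is None:
            setattr(model, name, NVFP4Linear.from_float(child))
        else:
            swap_linear_with_nvfp4(child)
    return model
```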
- Performance (a minimal benchmark sketch follows this list)
  - Random Hadamard transform runs at 80%+ of peak achievable memory bandwidth on Blackwell, for all block granularities
  - All quantization kernels run at 80%+ of peak achievable memory bandwidth on Blackwell
    - benchmark scripts for each quantization kernel
  - All gemm kernels run at 80%+ of peak achievable TFLOPs/sec on Blackwell
    - benchmark scripts for each gemm
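A minimal benchmark sketch for the bandwidth targets above, assuming a Triton quantization kernel `quant_fn`; the output byte accounting (packed fp4 plus one e4m3 scale per 16 elements) and the peak-bandwidth number are placeholders to substitute with the measured Blackwell figures.

```python
import torch
from triton.testing import do_bench

PEAK_GBPS = 8000.0  # placeholder peak HBM bandwidth; substitute the real device spec

def bench_quant_kernel(quant_fn, M=16384, K=16384):
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    ms = do_bench(lambda: quant_fn(x))
    # bytes moved: read bf16 input, write packed fp4 output + one e4m3 scale per 16 elems
    bytes_moved = x.numel() * 2 + x.numel() // 2 + x.numel() // 16
    gbps = bytes_moved / (ms * 1e-3) / 1e9
    print(f"{quant_fn.__name__}: {ms:.3f} ms, {gbps:.0f} GB/s, {100 * gbps / PEAK_GBPS:.1f}% of peak")
```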
- Integration into torchtitan
  - Validate loss convergence is virtually identical to bf16 for 3k+ steps on full size Llama3 8b/70b
  - Validate e2e throughput (TPS) improvement in the same training run as above
- Documentation
  - README
  - torchao docsite
Support for MoE layers (grouped GEMMs)
We can extend the low precision MoE training code here to support NVFP4 by doing the following:
- Functionality
  - Quantization (a reference sketch of 16x16 block scaling for 3d expert weights follows this list)
    - 16x16 scaling quantization kernel for 3d expert weights (alternative: update the existing kernel to be compatible and performant for 3d inputs)
    - Triton kernel for per-group conversion to blocked swizzled format where groups are along M (can probably reuse the one we wrote for mxfp8 here)
    - Triton kernel for per-group conversion to blocked swizzled format where groups are along the K/contracting dim (can probably reuse the ones we wrote for mxfp8 here)
    - Triton kernel for per-group conversion to blocked swizzled format for 3d expert weights (can probably reuse the one we wrote for mxfp8 here)
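A reference sketch (not the Triton kernel) of 16x16 block amax computation for 3d expert weights of shape (E, N, K): one scale per 16x16 tile of each expert's weight matrix, with the same two-level e4m3/fp32 scale recipe as the 1x16 sketch in the dense section. The function name is hypothetical.

```python
import torch

def nvfp4_16x16_block_amax(w: torch.Tensor) -> torch.Tensor:
    """w: (E, N, K) with N % 16 == 0 and K % 16 == 0.
    Returns per-tile amax of shape (E, N // 16, K // 16)."""
    E, N, K = w.shape
    tiles = w.reshape(E, N // 16, 16, K // 16, 16)
    return tiles.abs().amax(dim=(2, 4))
```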
  - 1x16 @ 16x16 scaled grouped gemm (see the bf16 grouped-gemm reference after this list)
    - Use for `output = input @ weight.transpose(-2,-1)`
    - Use for `dgrad = grad_output @ weight`
  - 1x16 @ 16x1 scaled grouped gemm
    - Use for `wgrad = grad_output.transpose(-2,-1) @ input`
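A plain bf16 loop-over-experts reference for the three grouped gemms above, to make the operand/scaling mapping concrete. The expert count and token-group offsets are made-up example values; the real kernels fuse the loop into a single scaled grouped gemm per case.

```python
import torch

E, K, N = 4, 256, 512
offs = [0, 100, 250, 400, 512]                   # token-group boundaries per expert
x = torch.randn(512, K, dtype=torch.bfloat16)    # routed tokens, grouped along M
w = torch.randn(E, N, K, dtype=torch.bfloat16)   # per-expert weights
g = torch.randn(512, N, dtype=torch.bfloat16)    # upstream grad, same grouping

out = torch.empty(512, N, dtype=torch.bfloat16)
dgrad = torch.empty(512, K, dtype=torch.bfloat16)
wgrad = torch.empty(E, N, K, dtype=torch.bfloat16)
for e in range(E):
    s, t = offs[e], offs[e + 1]
    out[s:t] = x[s:t] @ w[e].transpose(-2, -1)    # forward: 1x16 @ 16x16 (per expert)
    dgrad[s:t] = g[s:t] @ w[e]                    # dgrad:   1x16 @ 16x16
    wgrad[e] = g[s:t].transpose(-2, -1) @ x[s:t]  # wgrad:   1x16 @ 16x1, groups along K
```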
  - Autograd function implementing forward and backward with dynamic quant on inputs (see mxfp8 example)
    - Forward
      - 1x16 quant with RTNE on 2d input
      - 16x16 quant with RTNE on 3d weights
      - 1x16 @ 16x16 scaled mm for output
    - Backward
      - dgrad
        - 1x16 quant with SR on 2d upstream grad
        - 16x16 quant on 3d weights with RTNE
        - per group scale factor conversion to blocked swizzled format (2d grad)
        - per group scale factor conversion to blocked swizzled format (3d weight)
        - 1x16 @ 16x16 scaled mm for dgrad
      - wgrad
        - RHT on upstream grad (16x16 granularity)
        - RHT on input activations (16x16 granularity)
        - 1x16 quant with SR for 2d transposed upstream grad
        - 16x1 quant with RTNE for 2d input activations
        - 1x16 @ 16x1 scaled_mm for wgrad
  - Custom ops around all custom kernels for `torch.compile` composability
  - Tests for FSDP, TP, EP
- Performance
  - All quantization kernels run at 80%+ of peak achievable memory bandwidth on Blackwell
    - benchmark scripts for each quantization kernel
  - All gemm kernels run at 80%+ of peak achievable TFLOPs/sec on Blackwell
    - benchmark scripts for each gemm
- Integration into torchtitan
  - Validate loss convergence is virtually identical to bf16 for 3k+ steps on full size DeepSeekV3 671b
  - Validate e2e throughput (TPS) improvement in the same training run as above
- Documentation
  - README
  - torchao docsite