NVFP4 Training Tracker #3293

@danielvegamyhre

We want to support NVFP4 training in torchao for both dense and MoE models, following the recipe described in this paper from NVIDIA.

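For reference, NVFP4 uses two-level block scaling: fp4 (e2m1) element values, an fp8 (e4m3) scale per 16-element block, and a second-level fp32 per-tensor scale. Below is a minimal illustrative sketch of the 1x16 quantization math; this is not torchao code, the helper name is made up, and a real kernel would cast to e2m1 with RTNE or SR rather than just clamping:

```python
# Minimal illustrative sketch of NVFP4 1x16 two-level block scaling (not torchao code).
import torch

F4_E2M1_MAX = 6.0    # max magnitude representable in fp4 e2m1
F8_E4M3_MAX = 448.0  # max magnitude representable in fp8 e4m3

def nvfp4_quantize_1x16_reference(x: torch.Tensor):
    """Reference 1x16 block quantization for a 2d tensor; assumes K is a multiple of 16."""
    blocks = x.reshape(-1, 16)
    # second-level fp32 per-tensor scale
    tensor_scale = (x.abs().amax() / (F4_E2M1_MAX * F8_E4M3_MAX)).clamp(min=1e-12)
    # per-block e4m3 scale, stored after dividing out the per-tensor scale
    block_amax = blocks.abs().amax(dim=1, keepdim=True)
    block_scale = (block_amax / F4_E2M1_MAX / tensor_scale).to(torch.float8_e4m3fn)
    # scale the data; a real kernel would now cast to e2m1 with RTNE (fwd/weights) or SR (grads)
    scaled = blocks / (block_scale.float() * tensor_scale).clamp(min=1e-12)
    q = scaled.clamp(-F4_E2M1_MAX, F4_E2M1_MAX)
    return q.reshape(x.shape), block_scale.reshape(x.shape[0], -1), tensor_scale
```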

Support for dense models (linears)

  • Functionality
    • Random Hadamard transform for wgrad gemm inputs (16x16 block granularity)
    • Rounding modes ([RFC] NVFP4 Rounding Modes #3264)
      • Stochastic rounding for gradients
      • Round to nearest even (RTNE) for activations and weights (e.g. the PTX `.rn` modifier)
        • Triton kernel wrapping inline PTX with the `.rn` modifier, same approach as above
    • Quantization
      • 1x16 quantization for LHS operands (support RTNE and SR, since it will also be used for grads)
      • 16x1 quantization for RHS operands (RTNE only; if we compute wgrad = grad_out.t() @ input, the 16x1-scaled RHS operand holds input activations rather than grads, so SR is not needed)
      • 16x16 quantization for weights, using RTNE
      • 1x16 @ 16x16 scaled gemm
        • Use for output = input @ weight.t()
        • Use for dgrad = grad_output @ weight
      • 1x16 @ 16x1 scaled gemm
        • Use for wgrad = grad_output.t() @ input
    • Autograd function implementing forward and backward with dynamic quant as described in the paper (see the structural sketch after this list):
      • Forward
        • 1x16 quant with RTNE on 2d input
        • 16x16 quant with RTNE on 2d weights
        • 1x16 @ 16x16 scaled mm for output
      • Backward
        • dgrad
          • 1x16 quant with SR on 2d upstream grad
          • 16x16 quant on 2d weights with RTNE
          • per group scale factor conversion to blocked swizzled format (2d grad)
          • per group scale factor conversion to blocked swizzled format (2d weight)
          • 1x16 @ 16x16 scaled mm for dgrad
        • wgrad
          • RHT on upstream grad (16x16 granularity)
          • RHT on input activations (16x16 granularity)
          • 1x16 quant with SR for 2d transposed upstream grad
          • 16x1 quant with RTNE for 2d input activations
          • 1x16 @ 16x1 scaled_mm for wgrad
    • NVFP4Linear wrapping autograd func (used for module swaps)
    • DTensor handling for TP support
    • Custom ops around all custom kernels for torch.compile composability
    • Tests for FSDP, TP
    • quantize_ model conversion API performing module swap of nn.Linear to NVFP4Linear (which wraps the autograd func)
  • Performance
    • Random Hadamard transform runs at 80%+ of peak achievable memory bandwidth on Blackwell, for all block granularities
    • All quantization kernels run at 80%+ of peak achievable memory bandwidth on Blackwell
      • benchmark scripts for each quantization kernel
    • All gemm kernels run at 80%+ of peak achievable TFLOPs/sec on Blackwell
      • benchmark scripts for each gemm
  • Integration into torchtitan
    • Validate loss convergence is virtually identical to bf16 for 3k+ steps on full-size Llama3 8b/70b
    • Validate e2e throughput (TPS) improvement in the same training run as above
  • Documentation
    • README
    • torchao docsite
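
For clarity, here is a structural sketch of how the dense autograd function above wires these pieces together. Every kernel name in it (nvfp4_quantize_1x16 / nvfp4_quantize_16x1 / nvfp4_quantize_16x16, random_hadamard_transform, to_blocked_swizzled, nvfp4_scaled_mm) is a hypothetical placeholder for a planned op, not an existing torchao API; only the dataflow is intended to match the recipe described above.

```python
# Structural sketch only -- the kernel names below are hypothetical placeholders.
import torch

class _NVFP4LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_2d, weight):
        ctx.save_for_backward(input_2d, weight)
        # 1x16 quant with RTNE on 2d input, 16x16 quant with RTNE on 2d weight
        a_q, a_s = nvfp4_quantize_1x16(input_2d, rounding="rtne")
        w_q, w_s = nvfp4_quantize_16x16(weight, rounding="rtne")
        # output = input @ weight.t() via 1x16 @ 16x16 scaled gemm
        return nvfp4_scaled_mm(a_q, w_q.t(), to_blocked_swizzled(a_s), to_blocked_swizzled(w_s))

    @staticmethod
    def backward(ctx, grad_out):
        input_2d, weight = ctx.saved_tensors

        # dgrad = grad_out @ weight: 1x16 SR quant on the grad, 16x16 RTNE quant on the weight
        g_q, g_s = nvfp4_quantize_1x16(grad_out, rounding="sr")
        w_q, w_s = nvfp4_quantize_16x16(weight, rounding="rtne")
        grad_input = nvfp4_scaled_mm(g_q, w_q, to_blocked_swizzled(g_s), to_blocked_swizzled(w_s))

        # wgrad = grad_out.t() @ input: RHT both operands first (16x16 granularity),
        # then 1x16 SR quant on the transposed grad and 16x1 RTNE quant on the input
        g_rht = random_hadamard_transform(grad_out)
        x_rht = random_hadamard_transform(input_2d)
        gt_q, gt_s = nvfp4_quantize_1x16(g_rht.t().contiguous(), rounding="sr")
        x_q, x_s = nvfp4_quantize_16x1(x_rht, rounding="rtne")
        grad_weight = nvfp4_scaled_mm(gt_q, x_q, to_blocked_swizzled(gt_s), to_blocked_swizzled(x_s))

        return grad_input, grad_weight
```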

Support for MoE layers (grouped GEMMs)

We can extend the existing low-precision MoE training code here to support NVFP4 by doing the following:

  • Functionality
    • Quantization
      • 16x16 scaling quantization kernel for 3d expert weights (alternative: update existing kernel to be compatible and performant for 3d inputs)
      • Triton kernel for per-group conversion to blocked swizzled format where groups are along M (can probably reuse the one we wrote for mxfp8 here)
      • Triton kernel for per-group conversion to blocked swizzled format where groups are along the K/contracting dim (can probably reuse the ones we wrote for mxfp8 here)
      • Triton kernel for per-group conversion to blocked swizzled format for 3d expert weights (can probably re-use one we wrote for mxfp8 here)
      • 1x16 @ 16x16 scaled grouped gemm
        • Use for output = input @ weight.transpose(-2,-1)
        • Use for dgrad = grad_output @ weight
      • 1x16 @ 16x1 scaled grouped gemm
        • Use for wgrad = grad_output.transpose(-2,-1) @ input
    • Autograd function implementing forward and backward with dynamic quant on inputs (see mxfp8 example; the forward dataflow is sketched after this list)
      • Forward
        • 1x16 quant with RTNE on 2d input
        • 16x16 quant with RTNE on 3d weights
        • 1x16 @ 16x16 scaled mm for output
      • Backward
        • dgrad
          • 1x16 quant with SR on 2d upstream grad
          • 16x16 quant on 3d weights with RTNE
          • per group scale factor conversion to blocked swizzled format (2d grad)
          • per group scale factor conversion to blocked swizzled format (3d weight)
          • 1x16 @ 16x16 scaled mm for dgrad
        • wgrad
          • RHT on upstream grad (16x16 granularity)
          • RHT on input activations (16x16 granularity)
          • 1x16 quant with SR for 2d transposed upstream grad
          • 16x1 quant with RTNE for 2d input activations
          • 1x16 @ 16x1 scaled_mm for wgrad
    • Custom ops around all custom kernels for torch.compile composability
    • Tests for FSDP, TP, EP
  • Performance
    • All quantization kernels run at 80%+ of peak achievable memory bandwidth on Blackwell
      • benchmark scripts for each quantization kernel
    • All gemm kernels run at 80%+ of peak achievable TFLOPs/sec on Blackwell
      • benchmark scripts for each gemm
  • Integration into torchtitan
    • Validate loss convergence is virtually identical to bf16 for 3k+ steps on full-size DeepSeekV3 671b
    • Validate e2e throughput (TPS) improvement in the same training run as above
  • Documentation
    • README
    • torchao docsite
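
As with the dense case, here is a structural sketch of the MoE forward path, where token groups lie along M and are delimited by offsets. The kernel names (nvfp4_quantize_1x16, nvfp4_quantize_16x16_3d, to_blocked_swizzled_per_group, nvfp4_scaled_grouped_mm) are hypothetical placeholders for the planned ops; only the dataflow is intended to be accurate.

```python
# Structural sketch only -- the kernel names below are hypothetical placeholders.
import torch

def nvfp4_moe_forward(x: torch.Tensor,      # (total_tokens, K) permuted expert inputs, 2d
                      w: torch.Tensor,      # (num_experts, N, K) expert weights, 3d
                      offs: torch.Tensor):  # (num_experts,) end offsets of token groups along M
    # 1x16 quant with RTNE on the 2d input, 16x16 quant with RTNE on the 3d expert weights
    x_q, x_scales = nvfp4_quantize_1x16(x, rounding="rtne")
    w_q, w_scales = nvfp4_quantize_16x16_3d(w, rounding="rtne")
    # per-group conversion of scale factors to the blocked swizzled layout
    x_scales_blocked = to_blocked_swizzled_per_group(x_scales, offs=offs)  # groups along M
    w_scales_blocked = to_blocked_swizzled_per_group(w_scales)             # one group per expert (dim 0)
    # output = input @ weight.transpose(-2, -1) via 1x16 @ 16x16 scaled grouped gemm
    return nvfp4_scaled_grouped_mm(
        x_q, w_q.transpose(-2, -1), x_scales_blocked, w_scales_blocked, offs=offs
    )
```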
