 # LICENSE file in the root directory of this source tree.

 import logging
+from functools import partial
 from typing import Optional

 import torch
@@ -38,15 +39,15 @@
 logger: logging.Logger = logging.getLogger(__name__)


-def _scaled_grouped_mm(
+def _quantize_then_scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
-    This function performs dynamic float8 quantization with row-wise scaling
+    This function performs dynamic quantization with the given recipe
     on the input tensors A and B, then performs a scaled grouped GEMM and returns the results.

     Args:
@@ -78,6 +79,15 @@ def _scaled_grouped_mm(
         raise ValueError(f"Unsupported scaling type {scaling_type}")


+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
+)
+_to_fp8_rowwise_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
+)
+
+
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""

@@ -89,7 +99,7 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"

@@ -113,7 +123,7 @@ def forward(

         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
+            f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
         )

         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
@@ -295,7 +305,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
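
For context, a minimal usage sketch of the renamed entry point and the new partial-based aliases (not part of this commit; the module path, the CUDA requirement, and all shapes/offsets below are assumptions for illustration only):

import torch

# Assumed import path; adjust to wherever this module lives in your checkout.
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _quantize_then_scaled_grouped_mm,
    _to_fp8_rowwise_then_scaled_grouped_mm,
)

num_experts, total_tokens, k, n = 4, 256, 512, 1024
A = torch.randn(total_tokens, k, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(num_experts, k, n, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# Cumulative end offset of each expert's token group along A's first dim (illustrative).
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

# The default recipe is MoEScalingType.FP8_ROWWISE per the signature above...
out = _quantize_then_scaled_grouped_mm(A, B_t, offs=offs)
# ...or pin the recipe explicitly via the functools.partial alias.
out = _to_fp8_rowwise_then_scaled_grouped_mm(A, B_t, offs=offs)
out.sum().backward()  # the grouped GEMM is differentiable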