 # LICENSE file in the root directory of this source tree.

 import logging
-from functools import partial
 from typing import Optional

 import torch
     ScaleCalculationMode,
 )
 from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper

 logger: logging.Logger = logging.getLogger(__name__)
@@ -79,15 +79,6 @@ def _quantize_then_scaled_grouped_mm(
         raise ValueError(f"Unsupported scaling type {scaling_type}")


-# Aliases for convenience/clarity
-_to_mxfp8_then_scaled_grouped_mm = partial(
-    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
-)
-_to_fp8_rowwise_then_scaled_grouped_mm = partial(
-    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
-)
-
-
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""

@@ -304,6 +295,7 @@ def forward(
         block_size: int = 32,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
+        use_triton_for_dim0_cast: bool = False,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -313,17 +305,28 @@ def forward(

         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_data, A_scale = triton_to_mxfp8_dim0(
-            A,
-            inner_block_size=block_size,
-        )
-
-        # B_data shape: (E, N, K)
-        # B_scale shape: (E, N, K//block_size)
-        B_data, B_scales = triton_to_mxfp8_dim0(
-            B_t.transpose(-2, -1),
-            inner_block_size=block_size,
-        )
+        if use_triton_for_dim0_cast:
+            A_data, A_scale = triton_to_mxfp8_dim0(
+                A,
+                inner_block_size=block_size,
+            )
+            # B_data shape: (E, N, K)
+            # B_scale shape: (E, N, K//block_size)
+            B_data, B_scales = triton_to_mxfp8_dim0(
+                B_t.transpose(-2, -1),
+                inner_block_size=block_size,
+            )
+        else:
+            A_scale, A_data = to_mx(
+                A,
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
+            B_scales, B_data = to_mx(
+                B_t.transpose(-2, -1),
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )

         # Convert scales to blocked format for 2d-3d grouped mm
         _, blocked_scales_group_offsets_2d3d = (
@@ -351,19 +354,28 @@ def forward(
         ctx.block_size = block_size
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
+        ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
         return out

     @staticmethod
     def backward(ctx, grad_out: torch.Tensor):
         A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
+        use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
-            grad_out, inner_block_size=block_size
-        )
+        if use_triton_for_dim0_cast:
+            grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+                grad_out, inner_block_size=block_size
+            )
+        else:
+            grad_out_scale, grad_out_data = to_mx(
+                grad_out,
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
@@ -449,7 +461,7 @@ def backward(ctx, grad_out: torch.Tensor):
         )
         # grad_B_t shape = (E,K,N)
         grad_B_t = grad_B.transpose(-2, -1)
-        return grad_A, grad_B_t, None, None, None
+        return grad_A, grad_B_t, None, None, None, None


 def _to_mxfp8_dim1_3d(
@@ -659,3 +671,8 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(

 def round_up(x, y):
     return ((x + y - 1) // y) * y
+
+
+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = _MXFP8GroupedMM.apply
+_to_fp8_rowwise_then_scaled_grouped_mm = _Float8GroupedMM.apply
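
For reference, below is a minimal usage sketch of the updated _to_mxfp8_then_scaled_grouped_mm alias. It is not part of the diff: the tensor sizes, the offsets layout, and the import path are illustrative assumptions, and the positional argument order simply mirrors the forward() signature shown above, with the new use_triton_for_dim0_cast flag passed last.

import torch

# Assumed import path for illustration only; the diff does not show the file's location.
from torchao.prototype.moe_training.scaled_grouped_mm import _to_mxfp8_then_scaled_grouped_mm

# Illustrative sizes: M routed tokens, E experts, K contraction dim, N output dim.
M, K, N, E = 256, 512, 1024, 4
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)       # 2D activations
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True)  # 3D expert weights, pre-transposed to (E, K, N)
# Assumed layout: cumulative end offset of each expert's token group in A.
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

# Requires hardware/kernels that support mxfp8 scaled grouped GEMM (or the emulated path).
out = _to_mxfp8_then_scaled_grouped_mm(
    A,
    B_t,
    offs,
    32,              # block_size
    torch.bfloat16,  # out_dtype
    False,           # emulated
    False,           # use_triton_for_dim0_cast: False -> to_mx reference cast, True -> triton_to_mxfp8_dim0
)
out.sum().backward()  # backward re-quantizes grad_out along dim0, honoring the same flag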