Commit f657903

[mxfp8 moe training] make scaling mode configurable and make rceil default (#3271)
1 parent 2a89491 commit f657903
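
For context: the scale calculation mode controls how each block's shared power-of-two (E8M0) scale is derived from the block's absolute max before the cast to float8_e4m3fn. FLOOR truncates the exponent, which can leave the largest block element above the e4m3 max and force a saturating clamp; RCEIL rounds the scale up so every scaled element fits. Below is a minimal sketch of the idea only, not torchao's to_mx implementation (which also handles zeros, NaNs, and E8M0 exponent clamping):

import math

E4M3_MAX = 448.0  # largest normal float8_e4m3fn magnitude

def floor_scale(amax: float) -> float:
    # FLOOR (OCP MX spec style): shared exponent = floor(log2(amax)) - emax,
    # with emax = 8 for e4m3. The scale can come out too small, so the
    # largest block element may still exceed E4M3_MAX and get clamped.
    return 2.0 ** (math.floor(math.log2(amax)) - 8)

def rceil_scale(amax: float) -> float:
    # RCEIL: smallest power-of-two scale with amax / scale <= E4M3_MAX,
    # i.e. round the required exponent up instead of down -> no overflow.
    return 2.0 ** math.ceil(math.log2(amax / E4M3_MAX))

amax = 500.0
print(amax / floor_scale(amax))  # 500.0 -> exceeds 448, saturating cast
print(amax / rceil_scale(amax))  # 250.0 -> representable, no clamping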

1 file changed: 11 additions, 3 deletions

torchao/prototype/moe_training/scaled_grouped_mm.py

@@ -296,6 +296,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
         use_triton_for_dim0_cast: bool = False,
+        scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -321,11 +322,13 @@ def forward(
             A,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )
         B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )

         # Convert scales to blocked format for 2d-3d grouped mm
@@ -355,6 +358,7 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
         ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
+        ctx.scale_calculation_mode = scale_calculation_mode
         return out

     @staticmethod
@@ -363,6 +367,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
         use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
+        scale_calculation_mode = ctx.scale_calculation_mode

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
@@ -375,13 +380,16 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
         B = B_t.transpose(-2, -1)
         B_data, B_scales = mxfp8_quantize_cuda_3d(
-            B._data if hasattr(B, "_data") else B, block_size=block_size
+            B._data if hasattr(B, "_data") else B,
+            block_size=block_size,
+            scaling_mode=scale_calculation_mode.value.lower(),
         )
         # (E, N//block_size, K) -> (E, K, N//block_size)
         B_scales = B_scales.transpose(-2, -1)
@@ -413,7 +421,7 @@ def backward(ctx, grad_out: torch.Tensor):
             hp_dtype=grad_out.dtype,
             gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
             cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+            scale_calculation_mode=scale_calculation_mode,
         )
         grad_out_t_data = grad_out_t_mx.qdata
         grad_out_t_scales = grad_out_t_mx.scale
@@ -429,7 +437,7 @@ def backward(ctx, grad_out: torch.Tensor):
             hp_dtype=A.dtype,
             gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
             cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+            scale_calculation_mode=scale_calculation_mode,
         )
         A_t_data = A_t_mx.qdata
         A_t_scales = A_t_mx.scale
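
With this change every cast in the forward and backward path uses the same configurable mode, defaulting to RCEIL instead of the previously hard-coded FLOOR; the CUDA 3d quantize kernel receives the mode as a lowercase string via scale_calculation_mode.value.lower(). A hedged usage sketch follows: the offs argument, the import path, and the keyword plumbing through the public wrapper are assumptions; only the scale_calculation_mode parameter and its RCEIL default come from the diff above.

from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode  # import path assumed

# Default now uses RCEIL scaling (scale rounded up; no overflow clamping):
out = _quantize_then_scaled_grouped_mm(A, B_t, offs=offs)  # hypothetical call site

# Opting back into the previous spec-style behavior:
out = _quantize_then_scaled_grouped_mm(
    A, B_t, offs=offs,
    scale_calculation_mode=ScaleCalculationMode.FLOOR,
)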
