@@ -296,6 +296,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
         use_triton_for_dim0_cast: bool = False,
+        scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -321,11 +322,13 @@ def forward(
             A,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )
         B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )

         # Convert scales to blocked format for 2d-3d grouped mm
@@ -355,6 +358,7 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
         ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
+        ctx.scale_calculation_mode = scale_calculation_mode
         return out

     @staticmethod
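
In the forward pass, the new scale_calculation_mode argument (defaulting to ScaleCalculationMode.RCEIL) is threaded into both to_mx casts and stashed on ctx so the backward pass reuses the same mode. Below is a rough, hand-rolled sketch of what the mode controls, i.e. floor-style versus round-up power-of-two block scales; this is not torchao's to_mx implementation, and the library's exact recipes differ in detail:

import torch

F8E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def pow2_block_scale(block: torch.Tensor, mode: str) -> torch.Tensor:
    # Pick a power-of-two scale for one block from its absolute maximum.
    ratio = block.abs().amax() / F8E4M3_MAX
    # "floor" may leave the scaled amax above F8E4M3_MAX (clamped below);
    # the round-up ("rceil"-like) choice guarantees the scaled block fits in range.
    exp = torch.floor(torch.log2(ratio)) if mode == "floor" else torch.ceil(torch.log2(ratio))
    return torch.exp2(exp)

block = torch.randn(32) * 100  # one MX block (block_size = 32)
for mode in ("floor", "round_up"):
    scale = pow2_block_scale(block, mode)
    q = (block / scale).clamp(-F8E4M3_MAX, F8E4M3_MAX).to(torch.float8_e4m3fn)
    print(mode, scale.item(), q.float().abs().amax().item())
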
@@ -363,6 +367,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
         use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
+        scale_calculation_mode = ctx.scale_calculation_mode

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
@@ -375,13 +380,16 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
         B = B_t.transpose(-2, -1)
         B_data, B_scales = mxfp8_quantize_cuda_3d(
-            B._data if hasattr(B, "_data") else B, block_size=block_size
+            B._data if hasattr(B, "_data") else B,
+            block_size=block_size,
+            scaling_mode=scale_calculation_mode.value.lower(),
         )
         # (E, N//block_size, K) -> (E, K, N//block_size)
         B_scales = B_scales.transpose(-2, -1)
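
Note that, unlike to_mx, mxfp8_quantize_cuda_3d takes its scaling mode as a lowercase string rather than the enum, hence the scale_calculation_mode.value.lower() conversion above. A minimal sketch of that mapping, assuming ScaleCalculationMode is a string-valued enum as the call implies:

mode = ScaleCalculationMode.RCEIL
kernel_arg = mode.value.lower()  # e.g. "rceil", the spelling the CUDA 3d kernel expects
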
@@ -413,7 +421,7 @@ def backward(ctx, grad_out: torch.Tensor):
             hp_dtype=grad_out.dtype,
             gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
             cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+            scale_calculation_mode=scale_calculation_mode,
         )
         grad_out_t_data = grad_out_t_mx.qdata
         grad_out_t_scales = grad_out_t_mx.scale
@@ -429,7 +437,7 @@ def backward(ctx, grad_out: torch.Tensor):
             hp_dtype=A.dtype,
             gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
             cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+            scale_calculation_mode=scale_calculation_mode,
         )
         A_t_data = A_t_mx.qdata
         A_t_scales = A_t_mx.scale
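
With these changes every quantization site in the autograd function honors the caller-selected mode; the two dim1 casts in the backward pass previously hard-coded ScaleCalculationMode.FLOOR. A minimal caller-side sketch of the to_mx call the forward pass now makes (the import paths are assumptions and may vary across torchao versions):

import torch
# Assumed import location; adjust for your torchao version.
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx

A = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
# Same call shape as the forward pass above: per-block scales plus fp8 data,
# with the scales computed under the RCEIL recipe.
A_scales, A_data = to_mx(
    A,
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    scaling_mode=ScaleCalculationMode.RCEIL,
)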