@@ -394,7 +394,7 @@ def backward(ctx, grad_out: torch.Tensor):
             scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
         grad_out_t_data = grad_out_t_mx.qdata
-        grad_out_t_scales = grad_out_t_mx._scale_e8m0
+        grad_out_t_scales = grad_out_t_mx.scale
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A shape: (M, K)
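For context, the transpose-then-quantize step can be sketched in plain PyTorch. `quantize_blockwise` below is a hypothetical stand-in for torchao's `to_mx` with `ScaleCalculationMode.FLOOR` (names and signature are illustrative, not the library API): it derives one power-of-two scale per `block_size` elements along the last dim, which is why `A` and `grad_out` are transposed first so the blocks run along M.

    import torch

    # Hypothetical stand-in for to_mx with ScaleCalculationMode.FLOOR: one
    # power-of-two scale per `block_size` elements along the last dim.
    def quantize_blockwise(x: torch.Tensor, block_size: int = 32):
        blocks = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
        amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
        scales = torch.exp2(torch.floor(torch.log2(amax)))  # FLOOR rounding
        qdata = (blocks / scales).reshape_as(x)
        return qdata, scales.squeeze(-1)

    M, K, block_size = 64, 16, 32
    A = torch.randn(M, K)
    # Blocks run along the last dim, so transpose to scale along M:
    A_t_qdata, A_t_scales = quantize_blockwise(A.t().contiguous(), block_size)
    A_data = A_t_qdata.t()   # un-transpose back to (M, K)
    print(A_t_scales.shape)  # torch.Size([16, 2]) == (K, M // block_size)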
@@ -410,7 +410,7 @@ def backward(ctx, grad_out: torch.Tensor):
             scale_calculation_mode=ScaleCalculationMode.FLOOR,
         )
         A_t_data = A_t_mx.qdata
-        A_t_scales = A_t_mx._scale_e8m0
+        A_t_scales = A_t_mx.scale
 
         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
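The `offs // block_size` line converts element offsets into scale offsets: `offs` holds cumulative group boundaries along the quantized dim, and each e8m0 scale covers `block_size` elements, so integer division maps one unit to the other. A small illustration with made-up values:

    import torch

    block_size = 32
    offs = torch.tensor([64, 192, 256])       # hypothetical cumulative group ends, in elements
    scale_group_offsets = offs // block_size  # tensor([2, 6, 8]), in scale rows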
@@ -463,7 +463,7 @@ def _to_mxfp8_dim1_3d(
     )
     B_data = B_t_mx.qdata.t()  # (K, E*N) -> (E*N, K)
     B_data = B_data.reshape(E, N, K)  # (E*N, K) -> (E, N, K)
-    B_scales = B_t_mx._scale_e8m0.view(torch.uint8)  # (K, E*N//block_size)
+    B_scales = B_t_mx.scale.view(torch.uint8)  # (K, E*N//block_size)
     B_scales = B_scales.reshape(
         K, E, N // block_size
     )  # (K, E*N//block_size) -> (K, E, N//block_size)
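The reshapes in `_to_mxfp8_dim1_3d` are pure shape bookkeeping, which a dummy-tensor walkthrough makes concrete (dimensions are illustrative):

    import torch

    E, N, K, block_size = 4, 64, 16, 32
    B_t_data = torch.zeros(K, E * N)        # stand-in for B_t_mx.qdata
    B_data = B_t_data.t().reshape(E, N, K)  # (K, E*N) -> (E*N, K) -> (E, N, K)
    B_scales = torch.zeros(K, E * N // block_size, dtype=torch.uint8)  # stand-in for B_t_mx.scale
    B_scales = B_scales.reshape(K, E, N // block_size)
    print(B_data.shape, B_scales.shape)     # (4, 64, 16) (16, 4, 2)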