Skip to content

Commit ce9f009

Browse files
authored
fixup missing callsites for _scale_e8m0 -> scale rename (#3173)
Update [ghstack-poisoned]
1 parent 8878f30 commit ce9f009

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

benchmarks/float8/profile_lowp_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def cast_with_to_blocked(x_hp):
392392
gemm_kernel_choice=config.gemm_kernel_choice,
393393
)
394394
m, k = x_hp.shape
395-
scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
395+
scale_blocked = to_blocked(x_mx.scale.reshape(m, k // config.block_size))
396396
return x_mx._data, scale_blocked
397397

398398
# this function is used for cast_only_dim0_dim1

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def backward(ctx, grad_out: torch.Tensor):
394394
scale_calculation_mode=ScaleCalculationMode.FLOOR,
395395
)
396396
grad_out_t_data = grad_out_t_mx.qdata
397-
grad_out_t_scales = grad_out_t_mx._scale_e8m0
397+
grad_out_t_scales = grad_out_t_mx.scale
398398

399399
# Transpose A so we can scale along the M dimension, then un-transpose.
400400
# A shape: (M, K)
@@ -410,7 +410,7 @@ def backward(ctx, grad_out: torch.Tensor):
410410
scale_calculation_mode=ScaleCalculationMode.FLOOR,
411411
)
412412
A_t_data = A_t_mx.qdata
413-
A_t_scales = A_t_mx._scale_e8m0
413+
A_t_scales = A_t_mx.scale
414414

415415
# Convert scales to blocked format for 2d-2d grouped mm
416416
scale_group_offsets = offs // block_size
@@ -463,7 +463,7 @@ def _to_mxfp8_dim1_3d(
463463
)
464464
B_data = B_t_mx.qdata.t() # (K, E*N) -> (E*N, K)
465465
B_data = B_data.reshape(E, N, K) # (E*N, K) -> (E, N, K)
466-
B_scales = B_t_mx._scale_e8m0.view(torch.uint8) # (K, E*N//block_size)
466+
B_scales = B_t_mx.scale.view(torch.uint8) # (K, E*N//block_size)
467467
B_scales = B_scales.reshape(
468468
K, E, N // block_size
469469
) # (K, E*N//block_size) -> (K, E, N//block_size)

0 commit comments

Comments (0)