Commit 8a4f586

rename MXTensor's _scale_e8m0 to scale (#3164)
Update [ghstack-poisoned]
1 parent fb1450d

3 files changed (+29, -29 lines)
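The change is a mechanical rename of MXTensor's scale attribute from _scale_e8m0 to scale; the tensor still stores one E8M0 exponent per block. A minimal before/after sketch of caller code (hedged illustration, not part of the diff; assumes torchao at this commit, an arbitrary bf16 input, and that CPU execution of to_mx is supported for this config, as in the tests below):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(4, 64, dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)  # block_size = 32, as in the tests

# before this commit: scales = x_mx._scale_e8m0
# after this commit:
scales = x_mx.scale  # per-block E8M0 scales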

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     assert b_data.is_contiguous()
     b_data = b_data.transpose(-1, -2)

-    a_scale = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+    a_scale = a_mx.scale.view(M, K // 32)
+    b_scale = b_mx.scale.view(N, K // 32)

     a_scale_block = to_blocked(a_scale)
     b_scale_block = to_blocked(b_scale)
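For context on the .view(M, K // 32) calls above: with a block size of 32 along the reduction dimension, an (M, K) operand carries M * K // 32 scales, which is why the test reshapes the renamed scale to (M, K // 32) before handing it to to_blocked. A shape-only sketch under the same assumptions as the snippet above (hedged illustration; dimensions are arbitrary multiples of 32):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

M, K = 128, 256
a = torch.randn(M, K, dtype=torch.bfloat16)
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)

a_scale = a_mx.scale.view(M, K // 32)  # one scale per 32-wide block of each row
assert a_scale.numel() == M * K // 32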

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 15 additions & 15 deletions
@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         assert data_mx.qdata.shape == (*prev_dims, K // 2)
     else:
         assert data_mx.qdata.shape == (*prev_dims, K)
-    assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+    assert data_mx.scale.shape == (*prev_dims, K // block_size)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -146,7 +146,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     assert torch.isnan(data_mx.qdata[0])
     assert torch.all(data_mx.qdata[1:] == 0)
     # fp32 denorm
@@ -168,7 +168,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 denorm
     # fmt: off
@@ -189,7 +189,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 some denorm
     # fmt: off
@@ -220,7 +220,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 some denorm
     # fmt: off
@@ -251,7 +251,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # zero
     data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
@@ -262,7 +262,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 normal
     # fmt: off
@@ -293,7 +293,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 normal
     # fmt: off
@@ -324,7 +324,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)


@@ -340,8 +340,8 @@ def test_exponent_nan_in(elem_dtype):
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
-    assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))
+    assert torch.all(torch.isnan(tensor_mx.scale[0]))
+    assert not torch.any(torch.isnan(tensor_mx.scale[1:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -507,8 +507,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_c = to_mx_c(x, elem_dtype, block_size)
     torch.testing.assert_close(
-        x_mx._scale_e8m0,
-        x_mx_c._scale_e8m0,
+        x_mx.scale,
+        x_mx_c.scale,
         atol=0,
         rtol=0,
     )
@@ -519,15 +519,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     pack_fp6 = False
     x_mx_dq = to_dtype(
         x_mx.qdata,
-        x_mx._scale_e8m0,
+        x_mx.scale,
         x_mx._elem_dtype,
         x_mx._block_size,
         hp_dtype,  # noqa: E501
         pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
         x_mx_c.qdata,
-        x_mx_c._scale_e8m0,
+        x_mx_c.scale,
         x_mx_c._elem_dtype,
         x_mx_c._block_size,
         hp_dtype,
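The assertions above only change the attribute name; the quantize/dequantize round-trip they exercise is untouched. A hedged sketch of that round-trip spelled with the renamed attribute (assumes the default scale calculation mode and a CPU-supported config; shapes divide evenly by the block size):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

block_size = 32
x = torch.randn(8, 64, dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size)

assert x_mx.qdata.shape == x.shape                # fp8 element storage
assert x_mx.scale.shape == (8, 64 // block_size)  # per-block E8M0 scales

x_dq = x_mx.to_dtype(torch.bfloat16)  # dequantize back to the original dtype
assert x_dq.shape == x.shape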

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 12 additions & 12 deletions
@@ -470,7 +470,7 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):


 class MXTensor(TorchAOBaseTensor):
-    tensor_data_names = ["qdata", "_scale_e8m0"]
+    tensor_data_names = ["qdata", "scale"]
     tensor_attribute_names = [
         "_elem_dtype",
         "_block_size",
@@ -548,10 +548,10 @@ def __new__(
         # TODO investigate
         assert target_numel == qdata.numel(), f"{target_numel} != {qdata.numel()}"

-        # `_scale_e8m0` has rank 1 and applies to a row-major memory layout of
+        # `scale` has rank 1 and applies to a row-major memory layout of
         # `qdata`
         self.qdata = qdata
-        self._scale_e8m0 = scale_e8m0_bits
+        self.scale = scale_e8m0_bits
         self._elem_dtype = elem_dtype
         self._block_size = block_size
         self._orig_dtype = orig_dtype
@@ -562,15 +562,15 @@ def __new__(

     def __repr__(self):
         # TODO better elem dtype print for fp4
-        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501
+        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501

     def _quantization_type(self):
         return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"

     def to_dtype(self, target_dtype):
         return to_dtype(
             self.qdata,
-            self._scale_e8m0,
+            self.scale,
             self._elem_dtype,
             self._block_size,
             target_dtype,
@@ -685,8 +685,8 @@ def _addmm_mx_dispatch(
     assert a._block_size == 32, f"Invalid block size {a._block_size}"
     assert b._block_size == 32, f"Invalid block size {b._block_size}"

-    a_scale = a._scale_e8m0.view(M, K // a._block_size)
-    b_scale = b._scale_e8m0.view(N, K // b._block_size)
+    a_scale = a.scale.view(M, K // a._block_size)
+    b_scale = b.scale.view(N, K // b._block_size)
     a_scale_block = to_blocked(a_scale)
     b_scale_block = to_blocked(b_scale)

@@ -757,7 +757,7 @@ def mx_t(func, types, args, kwargs):
     old = args[0]
     new = MXTensor(
         old.qdata.t(),
-        old._scale_e8m0,
+        old.scale,
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
@@ -801,7 +801,7 @@ def mx_view_op(func, types, args, kwargs):
     new_data = func(data, new_size, *args[2:], **kwargs)
     return MXTensor(
         new_data,
-        args[0]._scale_e8m0,
+        args[0].scale,
         args[0]._elem_dtype,
         args[0]._block_size,
         args[0]._orig_dtype,
@@ -821,7 +821,7 @@ def mx_slice(func, types, args, kwargs):
     M, K = x.shape[0], x.shape[1]

     # TODO why doesn't scale have shape?
-    scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+    scale_shaped = x.scale.view(M, K // x._block_size)

     if dim == 0:
         # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now
@@ -888,12 +888,12 @@ def mx_clone(func, types, args, kwargs):
 def mx_select(func, types, args, kwargs):
     old_mx_tensor, dim, index = args
     assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported"
-    assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), (
+    assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor.scale.shape), (
         "unsupported"
     )
     new_mx_tensor = old_mx_tensor.__class__(
         old_mx_tensor.qdata[index],
-        old_mx_tensor._scale_e8m0[index],
+        old_mx_tensor.scale[index],
         old_mx_tensor._elem_dtype,
         old_mx_tensor._block_size,
         old_mx_tensor._orig_dtype,
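The remaining hunks rename the internal accesses in the dispatch handlers (transpose, view, slice, select), so wrappers produced by those ops expose the renamed attribute too. A hedged sketch of the transpose path (assumes a 2D CPU MXTensor as above; per mx_t in the diff, the new wrapper reuses the existing scale tensor while qdata is transposed):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

w = torch.randn(16, 64, dtype=torch.bfloat16)
w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn, 32)

w_mx_t = w_mx.t()  # routed through mx_t
print(w_mx_t.qdata.shape, w_mx_t.scale.shape)  # qdata transposed, scale forwarded unchanged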
