@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         assert data_mx.qdata.shape == (*prev_dims, K // 2)
     else:
         assert data_mx.qdata.shape == (*prev_dims, K)
-    assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size)
+    assert data_mx.scale.shape == (*prev_dims, K // block_size)


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -146,7 +146,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     assert torch.isnan(data_mx.qdata[0])
     assert torch.all(data_mx.qdata[1:] == 0)
     # fp32 denorm
@@ -168,7 +168,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 denorm
     # fmt: off
@@ -189,7 +189,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 some denorm
     # fmt: off
@@ -220,7 +220,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 some denorm
     # fmt: off
@@ -251,7 +251,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # zero
     data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
@@ -262,7 +262,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # fp32 normal
     # fmt: off
@@ -293,7 +293,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
     # bf16 normal
     # fmt: off
@@ -324,7 +324,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)


@@ -340,8 +340,8 @@ def test_exponent_nan_in(elem_dtype):
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
-    assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))
+    assert torch.all(torch.isnan(tensor_mx.scale[0]))
+    assert not torch.any(torch.isnan(tensor_mx.scale[1:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -507,8 +507,8 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_c = to_mx_c(x, elem_dtype, block_size)
     torch.testing.assert_close(
-        x_mx._scale_e8m0,
-        x_mx_c._scale_e8m0,
+        x_mx.scale,
+        x_mx_c.scale,
         atol=0,
         rtol=0,
     )
@@ -519,15 +519,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     pack_fp6 = False
     x_mx_dq = to_dtype(
         x_mx.qdata,
-        x_mx._scale_e8m0,
+        x_mx.scale,
         x_mx._elem_dtype,
         x_mx._block_size,
         hp_dtype,  # noqa: E501
         pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
         x_mx_c.qdata,
-        x_mx_c._scale_e8m0,
+        x_mx_c.scale,
         x_mx_c._elem_dtype,
         x_mx_c._block_size,
         hp_dtype,