Commit 1c7ceea

rename NVFP4Tensor's _scale_e4m3 field to scale (#3166)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 8a4f586 commit 1c7ceea
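
The rename is user-visible on every NVFP4Tensor: the blockwise float8_e4m3fn scales that previously lived in the private `_scale_e4m3` attribute are now exposed as `scale`. A minimal sketch of the new spelling, assuming a CUDA device and the prototype MX-formats API exercised by the tests below:

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# Quantize a high-precision tensor to NVFP4, as the updated tests do.
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()

# Blockwise scales are now read via `scale` instead of `_scale_e4m3`.
print(x.scale.dtype)  # torch.float8_e4m3fn, per the class docstring
print(x.qdata.shape)  # packed FP4 payload, two values per byte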

2 files changed: +31, -33 lines
test/prototype/mx_formats/test_nvfp4_tensor.py
Lines changed: 7 additions & 9 deletions

@@ -285,7 +285,7 @@ def test_nvfp4_swizzled_scales_view_semantics():
 
     # Test full-width column slicing (should maintain views)
     full_width_slice = tensor[:, 0:K]
-    assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr()
+    assert full_width_slice.scale.data_ptr() == tensor.scale.data_ptr()
     assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr()
 
 
@@ -394,9 +394,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         use_triton_kernel=True,
     )
 
-    torch.testing.assert_close(
-        nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten()
-    )
+    torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
     pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
     triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
     torch.testing.assert_close(
@@ -523,7 +521,7 @@ def test_nvfp4_to_copy():
     x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
     y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
     assert torch.equal(x.qdata, y.qdata)
-    assert torch.equal(x._scale_e4m3, y._scale_e4m3)
+    assert torch.equal(x.scale, y.scale)
     assert x._per_tensor_scale is None
     assert y._per_tensor_scale is None
     assert x._act_per_tensor_scale is None
@@ -586,20 +584,20 @@ def test_scale_shape_matches_qdata(
     if is_swizzled_scales:
         # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
         expected_padded_m = ceil_div(orig_m, 128) * 32
-    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    actual_padded_m = x.scale.shape[m_dim]
     assert expected_padded_m == actual_padded_m, (
-        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
     )
 
     orig_k = x_hp.shape[k_dim]
     expected_padded_k = orig_k // block_size
     if is_swizzled_scales:
         # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
         expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
-    actual_padded_k = x._scale_e4m3.shape[k_dim]
+    actual_padded_k = x.scale.shape[k_dim]
 
     assert expected_padded_k == actual_padded_k, (
-        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
     )
 
 
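
For reference, test_scale_shape_matches_qdata asserts the swizzled-scale padding rule shown above: expected_padded_m = ceil_div(orig_m, 128) * 32 and expected_padded_k = ceil_div(orig_k // block_size, 4) * 16. A small standalone sketch of that arithmetic (ceil_div is reimplemented here for illustration, and the helper name swizzled_scale_shape is hypothetical, not part of this commit):

def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division, mirroring the helper imported by the test.
    return (a + b - 1) // b

BLOCK_SIZE = 16  # NVFP4 block size, fixed at 16 per the class docstring

def swizzled_scale_shape(orig_m: int, orig_k: int) -> tuple[int, int]:
    # Padded scale dims when scales are stored swizzled, following the
    # expected_padded_m / expected_padded_k formulas in the test.
    padded_m = ceil_div(orig_m, 128) * 32
    padded_k = ceil_div(orig_k // BLOCK_SIZE, 4) * 16
    return padded_m, padded_k

print(swizzled_scale_shape(128, 64))   # (32, 16)
print(swizzled_scale_shape(200, 256))  # (64, 64): M rounds up to two row tiles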

torchao/prototype/mx_formats/nvfp4_tensor.py
Lines changed: 24 additions & 24 deletions

@@ -75,7 +75,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
 
     Attributes:
         qdata: Packed FP4 data (2 values per byte)
-        _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
+        scale: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
         _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
         _block_size (int): Block size for quantization (fixed at 16)
@@ -84,7 +84,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
         use_triton_kernel (bool): Whether to use triton kernels
     """
 
-    tensor_data_names = ["qdata", "_scale_e4m3"]
+    tensor_data_names = ["qdata", "scale"]
     tensor_attribute_names = [
         "_block_size",
         "_orig_dtype",
@@ -99,7 +99,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
     def __new__(
         cls,
         qdata,
-        blockwise_scales,
+        scale,
         block_size,
         orig_dtype,
         _per_tensor_scale=None,
@@ -125,7 +125,7 @@ def __new__(
         )
 
         self.qdata = qdata
-        self._scale_e4m3 = blockwise_scales
+        self.scale = scale
         self._block_size = block_size
         self._orig_dtype = orig_dtype
         self._per_tensor_scale = _per_tensor_scale
@@ -136,7 +136,7 @@ def __new__(
         return self
 
     def __repr__(self):
-        return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
 
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
@@ -258,10 +258,10 @@ def get_hp_scales(self) -> torch.Tensor:
         is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
             leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
-            scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
+            scale_e4m3 = self.scale.transpose(-2, -1)
         else:
             leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
-            scale_e4m3 = self._scale_e4m3
+            scale_e4m3 = self.scale
 
         if self._is_swizzled_scales:
             scale_e4m3 = from_blocked(
@@ -298,7 +298,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             and self._block_size == src._block_size
             and self._orig_dtype == src._orig_dtype
             and self._is_swizzled_scales == src._is_swizzled_scales
-            and self._scale_e4m3.shape == src._scale_e4m3.shape
+            and self.scale.shape == src.scale.shape
             and per_tensor_scale_equal
             and act_per_tensor_scale_equal
             and self.qdata.shape == src.qdata.shape
@@ -338,7 +338,7 @@ def nvfp4_to_copy(func, types, args, kwargs):
     if dtype is not None:
         res = NVFP4Tensor(
             tensor.qdata,
-            tensor._scale_e4m3,
+            tensor.scale,
             tensor._block_size,
             dtype,
             tensor._per_tensor_scale,
@@ -437,7 +437,7 @@ def nvfp4_slice(func, types, args, kwargs):
             )
 
             sliced_scale = aten.slice.Tensor(
-                x._scale_e4m3.flatten(), 0, start_idx, end_idx, 1
+                x.scale.flatten(), 0, start_idx, end_idx, 1
             )
             sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
@@ -481,7 +481,7 @@ def nvfp4_slice(func, types, args, kwargs):
 
             if start_col_block == 0 and end_col_block == n_col_blocks:
                 # Full width - no slicing needed
-                sliced_scale = x._scale_e4m3
+                sliced_scale = x.scale
             else:
                 # Extract specific column blocks from each row block
                 # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
@@ -493,7 +493,7 @@ def nvfp4_slice(func, types, args, kwargs):
                     row_start = row_block * elements_per_row_block
                     col_start = row_start + start_col_block * elements_per_block
                     col_end = row_start + end_col_block * elements_per_block
-                    slices_to_extract.append(x._scale_e4m3.flatten()[col_start:col_end])
+                    slices_to_extract.append(x.scale.flatten()[col_start:col_end])
 
                 # Concatenate all the slices
                 sliced_scale = torch.cat(slices_to_extract, dim=0)
@@ -511,7 +511,7 @@ def nvfp4_slice(func, types, args, kwargs):
             )
 
     else:
-        scale_shaped = x._scale_e4m3.view(M, K // x._block_size)
+        scale_shaped = x.scale.view(M, K // x._block_size)
 
         if dim == 0:
             sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step)
@@ -581,7 +581,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3.t(),
+        old.scale.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -600,7 +600,7 @@ def nvfp4_transpose(func, types, args, kwargs):
     valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
     assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
     new_qdata = func(old.qdata, dim0, dim1, **kwargs)
-    new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
+    new_scale = func(old.scale, dim0, dim1, **kwargs)
     new = NVFP4Tensor(
         new_qdata,
         new_scale,
@@ -623,7 +623,7 @@ def nvfp4_view_op(func, types, args, kwargs):
     new_data = func(data, new_size, *args[2:], **kwargs)
     return NVFP4Tensor(
         new_data,
-        args[0]._scale_e4m3,
+        args[0].scale,
         args[0]._block_size,
         args[0]._orig_dtype,
         args[0]._per_tensor_scale,
@@ -638,10 +638,10 @@ def nvfp4_view_op(func, types, args, kwargs):
 def nvfp4_select(func, types, args, kwargs):
     old, dim, index = args
     assert dim == 0, f"NVFP4Tensor aten.select.int with {dim=} is not yet supported"
-    assert len(old.qdata.shape) == len(old._scale_e4m3.shape), "unsupported"
+    assert len(old.qdata.shape) == len(old.scale.shape), "unsupported"
     new = old.__class__(
         old.qdata[index],
-        old._scale_e4m3[index],
+        old.scale[index],
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -661,9 +661,9 @@ def _addmm_nvfp4_dispatch(
     The only difference is whether bias is None or not.
     """
     assert a.qdata.is_contiguous()
-    assert a._scale_e4m3.is_contiguous()
+    assert a.scale.is_contiguous()
     assert b.qdata.t().is_contiguous()
-    assert b._scale_e4m3.t().is_contiguous()
+    assert b.scale.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
@@ -672,15 +672,15 @@ def _addmm_nvfp4_dispatch(
 
     # Swizzle Dizzle
     if a._is_swizzled_scales:
-        a_scale_blocked = a._scale_e4m3  # Already swizzled
+        a_scale_blocked = a.scale  # Already swizzled
     else:
-        a_scale = a._scale_e4m3.view(M, K // a._block_size)
+        a_scale = a.scale.view(M, K // a._block_size)
         a_scale_blocked = to_blocked(a_scale)
 
     if b._is_swizzled_scales:
-        b_scale_blocked = b._scale_e4m3.t()  # Already swizzled
+        b_scale_blocked = b.scale.t()  # Already swizzled
     else:
-        b_scale = b._scale_e4m3.t().view(N, K // b._block_size)
+        b_scale = b.scale.t().view(N, K // b._block_size)
         b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
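
Since `scale` replaces `_scale_e4m3` in tensor_data_names and throughout the dispatch ops above, external code that reached into the old private attribute needs the one-line rename. A hedged compatibility sketch for code that must run against torchao versions from before and after this commit (the helper name nvfp4_blockwise_scale is hypothetical, not part of this change):

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

def nvfp4_blockwise_scale(t: NVFP4Tensor) -> torch.Tensor:
    # Hypothetical helper: return the blockwise float8_e4m3fn scales under
    # whichever attribute name the installed torchao version uses.
    if hasattr(t, "scale"):
        return t.scale  # name introduced by this commit
    return t._scale_e4m3  # older torchao versions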
