@@ -75,7 +75,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
7575
7676 Attributes:
7777 qdata: Packed FP4 data (2 values per byte)
78- _scale_e4m3 : Blockwise scales in float8_e4m3fn format (may be swizzled)
78+ scale : Blockwise scales in float8_e4m3fn format (may be swizzled)
7979 _per_tensor_scale: Optional global per-tensor scale in float32 format
8080 _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation
8181 _block_size (int): Block size for quantization (fixed at 16)
@@ -84,7 +84,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
8484 use_triton_kernel (bool): Whether to use triton kernels
8585 """
8686
87- tensor_data_names = ["qdata" , "_scale_e4m3 " ]
87+ tensor_data_names = ["qdata" , "scale " ]
8888 tensor_attribute_names = [
8989 "_block_size" ,
9090 "_orig_dtype" ,
@@ -99,7 +99,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
9999 def __new__ (
100100 cls ,
101101 qdata ,
102- blockwise_scales ,
102+ scale ,
103103 block_size ,
104104 orig_dtype ,
105105 _per_tensor_scale = None ,
@@ -125,7 +125,7 @@ def __new__(
125125 )
126126
127127 self .qdata = qdata
128- self ._scale_e4m3 = blockwise_scales
128+ self .scale = scale
129129 self ._block_size = block_size
130130 self ._orig_dtype = orig_dtype
131131 self ._per_tensor_scale = _per_tensor_scale
@@ -136,7 +136,7 @@ def __new__(
136136 return self
137137
138138 def __repr__ (self ):
139- return f"NVFP4Tensor: blockwise_scales : { self ._scale_e4m3 } , per_tensor_scale: { self ._per_tensor_scale } , d: { self .qdata } , d_hp: { self .to_dtype (self ._orig_dtype )} "
139+ return f"NVFP4Tensor: scale : { self .scale } , per_tensor_scale: { self ._per_tensor_scale } , d: { self .qdata } , d_hp: { self .to_dtype (self ._orig_dtype )} "
140140
141141 def _quantization_type (self ):
142142 return f"{ self ._is_swizzled_scales = } , { self .use_triton_kernel = } , { self .act_quant_kwargs = } "
@@ -258,10 +258,10 @@ def get_hp_scales(self) -> torch.Tensor:
258258 is_transposed = self .qdata .stride (- 2 ) < self .qdata .stride (- 1 )
259259 if is_transposed :
260260 leading_dims , M , K = self .shape [:- 2 ], self .shape [- 1 ], self .shape [- 2 ]
261- scale_e4m3 = self ._scale_e4m3 .transpose (- 2 , - 1 )
261+ scale_e4m3 = self .scale .transpose (- 2 , - 1 )
262262 else :
263263 leading_dims , M , K = self .shape [:- 2 ], self .shape [- 2 ], self .shape [- 1 ]
264- scale_e4m3 = self ._scale_e4m3
264+ scale_e4m3 = self .scale
265265
266266 if self ._is_swizzled_scales :
267267 scale_e4m3 = from_blocked (
@@ -298,7 +298,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
298298 and self ._block_size == src ._block_size
299299 and self ._orig_dtype == src ._orig_dtype
300300 and self ._is_swizzled_scales == src ._is_swizzled_scales
301- and self ._scale_e4m3 .shape == src ._scale_e4m3 .shape
301+ and self .scale .shape == src .scale .shape
302302 and per_tensor_scale_equal
303303 and act_per_tensor_scale_equal
304304 and self .qdata .shape == src .qdata .shape
@@ -338,7 +338,7 @@ def nvfp4_to_copy(func, types, args, kwargs):
338338 if dtype is not None :
339339 res = NVFP4Tensor (
340340 tensor .qdata ,
341- tensor ._scale_e4m3 ,
341+ tensor .scale ,
342342 tensor ._block_size ,
343343 dtype ,
344344 tensor ._per_tensor_scale ,
@@ -437,7 +437,7 @@ def nvfp4_slice(func, types, args, kwargs):
437437 )
438438
439439 sliced_scale = aten .slice .Tensor (
440- x ._scale_e4m3 .flatten (), 0 , start_idx , end_idx , 1
440+ x .scale .flatten (), 0 , start_idx , end_idx , 1
441441 )
442442 sliced_data = aten .slice .Tensor (x .qdata , 0 , start , end , step )
443443
@@ -481,7 +481,7 @@ def nvfp4_slice(func, types, args, kwargs):
481481
482482 if start_col_block == 0 and end_col_block == n_col_blocks :
483483 # Full width - no slicing needed
484- sliced_scale = x ._scale_e4m3
484+ sliced_scale = x .scale
485485 else :
486486 # Extract specific column blocks from each row block
487487 # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
@@ -493,7 +493,7 @@ def nvfp4_slice(func, types, args, kwargs):
493493 row_start = row_block * elements_per_row_block
494494 col_start = row_start + start_col_block * elements_per_block
495495 col_end = row_start + end_col_block * elements_per_block
496- slices_to_extract .append (x ._scale_e4m3 .flatten ()[col_start :col_end ])
496+ slices_to_extract .append (x .scale .flatten ()[col_start :col_end ])
497497
498498 # Concatenate all the slices
499499 sliced_scale = torch .cat (slices_to_extract , dim = 0 )
@@ -511,7 +511,7 @@ def nvfp4_slice(func, types, args, kwargs):
511511 )
512512
513513 else :
514- scale_shaped = x ._scale_e4m3 .view (M , K // x ._block_size )
514+ scale_shaped = x .scale .view (M , K // x ._block_size )
515515
516516 if dim == 0 :
517517 sliced_scale = aten .slice .Tensor (scale_shaped , dim , start , end , step )
@@ -581,7 +581,7 @@ def nvfp4_t(func, types, args, kwargs):
581581 old = args [0 ]
582582 new = NVFP4Tensor (
583583 old .qdata .t (),
584- old ._scale_e4m3 .t (),
584+ old .scale .t (),
585585 old ._block_size ,
586586 old ._orig_dtype ,
587587 old ._per_tensor_scale ,
@@ -600,7 +600,7 @@ def nvfp4_transpose(func, types, args, kwargs):
600600 valid_3d_dims = ((1 , 2 ), (2 , 1 ), (- 1 , - 2 ), (- 2 , - 1 ))
601601 assert (dim0 , dim1 ) in valid_3d_dims , f"transpose unsupported for { dim0 = } { dim1 = } "
602602 new_qdata = func (old .qdata , dim0 , dim1 , ** kwargs )
603- new_scale = func (old ._scale_e4m3 , dim0 , dim1 , ** kwargs )
603+ new_scale = func (old .scale , dim0 , dim1 , ** kwargs )
604604 new = NVFP4Tensor (
605605 new_qdata ,
606606 new_scale ,
@@ -623,7 +623,7 @@ def nvfp4_view_op(func, types, args, kwargs):
623623 new_data = func (data , new_size , * args [2 :], ** kwargs )
624624 return NVFP4Tensor (
625625 new_data ,
626- args [0 ]._scale_e4m3 ,
626+ args [0 ].scale ,
627627 args [0 ]._block_size ,
628628 args [0 ]._orig_dtype ,
629629 args [0 ]._per_tensor_scale ,
@@ -638,10 +638,10 @@ def nvfp4_view_op(func, types, args, kwargs):
638638def nvfp4_select (func , types , args , kwargs ):
639639 old , dim , index = args
640640 assert dim == 0 , f"NVFP4Tensor aten.select.int with { dim = } is not yet supported"
641- assert len (old .qdata .shape ) == len (old ._scale_e4m3 .shape ), "unsupported"
641+ assert len (old .qdata .shape ) == len (old .scale .shape ), "unsupported"
642642 new = old .__class__ (
643643 old .qdata [index ],
644- old ._scale_e4m3 [index ],
644+ old .scale [index ],
645645 old ._block_size ,
646646 old ._orig_dtype ,
647647 old ._per_tensor_scale ,
@@ -661,9 +661,9 @@ def _addmm_nvfp4_dispatch(
661661 The only difference is whether bias is None or not.
662662 """
663663 assert a .qdata .is_contiguous ()
664- assert a ._scale_e4m3 .is_contiguous ()
664+ assert a .scale .is_contiguous ()
665665 assert b .qdata .t ().is_contiguous ()
666- assert b ._scale_e4m3 .t ().is_contiguous ()
666+ assert b .scale .t ().is_contiguous ()
667667 assert a ._block_size == 16 , f"NVFP4 requires block_size=16, got { a ._block_size } "
668668 assert b ._block_size == 16 , f"NVFP4 requires block_size=16, got { b ._block_size } "
669669
@@ -672,15 +672,15 @@ def _addmm_nvfp4_dispatch(
672672
673673 # Swizzle Dizzle
674674 if a ._is_swizzled_scales :
675- a_scale_blocked = a ._scale_e4m3 # Already swizzled
675+ a_scale_blocked = a .scale # Already swizzled
676676 else :
677- a_scale = a ._scale_e4m3 .view (M , K // a ._block_size )
677+ a_scale = a .scale .view (M , K // a ._block_size )
678678 a_scale_blocked = to_blocked (a_scale )
679679
680680 if b ._is_swizzled_scales :
681- b_scale_blocked = b ._scale_e4m3 .t () # Already swizzled
681+ b_scale_blocked = b .scale .t () # Already swizzled
682682 else :
683- b_scale = b ._scale_e4m3 .t ().view (N , K // b ._block_size )
683+ b_scale = b .scale .t ().view (N , K // b ._block_size )
684684 b_scale_blocked = to_blocked (b_scale )
685685
686686 # Merge double quant scales into 1 scale for Scale_In^D
0 commit comments