@@ -136,7 +136,7 @@ def __new__(
         return self

     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.dequantize(self._orig_dtype)}"

     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
@@ -217,7 +217,7 @@ def to_nvfp4(
     # Do not force the NVFP4Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl

-    def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Convert NVFP4Tensor back to high precision dtype.

         Args:
@@ -226,6 +226,8 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
+        if output_dtype is None:
+            output_dtype = self.dtype
         is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
             leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
@@ -242,7 +244,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
             *leading_dims, M, K // self._block_size, 1
         )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(output_dtype)

         if is_transposed:
             result = result.transpose(-2, -1)
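
A minimal usage sketch of the renamed API, assuming the to_nvfp4 constructor referenced in the hunk context above; shapes, dtype, and device are illustrative:

```python
import torch

x_hp = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
x_nvfp4 = to_nvfp4(x_hp)

# Explicit output dtype, mirroring the old to_dtype(target_dtype) call:
x_bf16 = x_nvfp4.dequantize(torch.bfloat16)

# New default: output_dtype=None falls back to self.dtype.
x_default = x_nvfp4.dequantize()
```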
@@ -731,7 +733,7 @@ def nvfp4_linear(func, types, args, kwargs):

     if weight_tensor.act_quant_kwargs is None:
         # weight_only quant
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
         # dynamic quant
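
On the weight-only path (act_quant_kwargs is None), nvfp4_linear falls back to a high-precision linear against the dequantized weight. A sketch of the equivalence, assuming a weight built with to_nvfp4 and left without act_quant_kwargs; names outside this diff are illustrative:

```python
import torch
import torch.nn.functional as F

w = to_nvfp4(torch.randn(256, 64, dtype=torch.bfloat16, device="cuda"))
x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")

# What the handler above computes (bias omitted for brevity):
y = F.linear(x, w.dequantize(w._orig_dtype))
```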
@@ -759,9 +761,9 @@ def nvfp4_mm(func, types, args, kwargs):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

     if weight_tensor.act_quant_kwargs is None:
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
-            input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
+            input_dequant = input_tensor.dequantize(input_tensor._orig_dtype)
             return func(input_dequant, weight_dequant)
         else:
             return func(input_tensor, weight_dequant)
@@ -791,9 +793,9 @@ def nvfp4_addmm(func, types, args, kwargs):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

     if weight_tensor.act_quant_kwargs is None:
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
-            input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
+            input_dequant = input_tensor.dequantize(input_tensor._orig_dtype)
             return torch.addmm(bias, input_dequant, weight_dequant)
         else:
             return torch.addmm(bias, input_tensor, weight_dequant)
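
nvfp4_mm and nvfp4_addmm share the same weight-only fallback: every NVFP4Tensor operand is dequantized to its _orig_dtype before the underlying op runs. A sketch of what the two handlers compute on that path, again assuming to_nvfp4-built operands with illustrative shapes:

```python
import torch

inp = to_nvfp4(torch.randn(32, 64, dtype=torch.bfloat16, device="cuda"))
wt = to_nvfp4(torch.randn(64, 16, dtype=torch.bfloat16, device="cuda"))
bias = torch.randn(16, dtype=torch.bfloat16, device="cuda")

inp_hp = inp.dequantize(inp._orig_dtype)
wt_hp = wt.dequantize(wt._orig_dtype)
out_mm = torch.mm(inp_hp, wt_hp)              # nvfp4_mm path
out_addmm = torch.addmm(bias, inp_hp, wt_hp)  # nvfp4_addmm path
```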