
Commit fe9be99

rename MXTensor and NVFP4Tensor's to_dtype to dequantize (#3169)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent ba4593f commit fe9be99

6 files changed: +41 -34 lines changed

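A minimal before/after sketch of the rename at a call site (the constructor arguments follow the test diffs below; the block size of 32 and the shapes are illustrative):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_hp = torch.randn(128, 128, dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)  # elem_dtype, block_size

# old spelling (before this commit):
#   x_dq = x_mx.to_dtype(torch.bfloat16)
# new spelling:
x_dq = x_mx.dequantize(torch.bfloat16)
```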

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 5 additions & 2 deletions
@@ -59,13 +59,16 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     local_rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     assert size % world_size == 0, "unsupported"
-    x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
+    x_fp8_fp32 = x_fp8.dequantize(torch.float32)
     rows_per_slice = size // world_size
     slice_start = local_rank * rows_per_slice
     slice_end = (local_rank + 1) * rows_per_slice
     x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
     torch.testing.assert_close(
-        x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
+        x_fp8_fp32_slice,
+        dist_x_fp8.to_local().dequantize(torch.float32),
+        atol=0,
+        rtol=0,
     )

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 3 additions & 3 deletions
@@ -49,9 +49,9 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     a_scale_block = to_blocked(a_scale)
     b_scale_block = to_blocked(b_scale)

-    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-        -1, -2
-    )
+    out_hp = a_mx.dequantize(torch.bfloat16) @ b_mx.dequantize(
+        torch.bfloat16
+    ).transpose(-1, -2)
     out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

     return compute_error(out_hp, out).item()

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 7 additions & 7 deletions
@@ -53,7 +53,7 @@ def _test_mx(
     data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR
 ):
     data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode)
-    data_mx_dq = data_mx.to_dtype(data_hp.dtype)
+    data_mx_dq = data_mx.dequantize(data_hp.dtype)

     def assert_sqnr_gt_threshold(orig, new, threshold):
         sqnr = compute_error(orig, new)

@@ -389,7 +389,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         pack_fp6,
         None,
     )
-    tensor_hp = tensor_mx.to_dtype(torch.float)
+    tensor_hp = tensor_mx.dequantize(torch.float)
     assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
     assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))

@@ -436,10 +436,10 @@ def test_transpose(elem_dtype):
         elem_dtype,
         block_size,
     )
-    tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
+    tensor_mx_dq_t = tensor_mx.dequantize(tensor_hp.dtype).t()

     tensor_mx_t = tensor_mx.t()
-    tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
+    tensor_mx_t_dq = tensor_mx_t.dequantize(tensor_hp.dtype)

     assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)

@@ -461,8 +461,8 @@ def test_clone():
     data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
     data_mx_c = data_mx.clone()
     torch.testing.assert_close(
-        data_mx.to_dtype(torch.bfloat16),
-        data_mx_c.to_dtype(torch.bfloat16),
+        data_mx.dequantize(torch.bfloat16),
+        data_mx_c.dequantize(torch.bfloat16),
         atol=0,
         rtol=0,
     )

@@ -571,7 +571,7 @@ def test_index_select():

     x_mx_1 = x_mx[1]
     torch.testing.assert_close(
-        x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0
+        x_mx.dequantize(x.dtype)[1], x_mx_1.dequantize(x.dtype), atol=0, rtol=0
     )

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 9 additions & 9 deletions
@@ -58,7 +58,7 @@ def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
         scale = None

     x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
-    x_reconstructed = x_nvfp4.to_dtype(dtype)
+    x_reconstructed = x_nvfp4.dequantize(dtype)

     def assert_sqnr_gt_threshold(orig, new, threshold):
         sqnr = compute_error(orig, new)

@@ -91,7 +91,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     x_nvfp4_t = x_nvfp4.transpose(-2, -1)
     x_t = x.transpose(-2, -1)

-    x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
+    x_reconstructed_t = x_nvfp4_t.dequantize(dtype)
     assert_sqnr_gt_threshold(x_t, x_reconstructed_t, 8.0)

     assert x_t.shape == x_reconstructed_t.shape, (

@@ -127,7 +127,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):

     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
     assert tensor._is_swizzled_scales == is_swizzled_scales
-    reconstructed = tensor.to_dtype(torch.bfloat16)
+    reconstructed = tensor.dequantize(torch.bfloat16)
     assert reconstructed.shape == data.shape

@@ -181,10 +181,10 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     assert sliced_tensor._is_swizzled_scales == True

     # Verify sliced tensor can be dequantized
-    sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)
+    sliced_reconstructed = sliced_tensor.dequantize(torch.bfloat16)

     # Compare with direct slicing of original data
-    original_reconstructed = tensor.to_dtype(torch.bfloat16)
+    original_reconstructed = tensor.dequantize(torch.bfloat16)
     if slice_dim == 0:
         expected = original_reconstructed[slice_spec, :]
     else:

@@ -324,8 +324,8 @@ def test_nvfp4_swizzled_scales_serialization():
     assert reconstructed_tensor._is_swizzled_scales == True

     # Verify functionality is preserved
-    original_dq = original_tensor.to_dtype(torch.bfloat16)
-    reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16)
+    original_dq = original_tensor.dequantize(torch.bfloat16)
+    reconstructed_dq = reconstructed_tensor.dequantize(torch.bfloat16)

     torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)

@@ -404,8 +404,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         rtol=0,
     )

-    x_pt_dequant = nvfp4_pt.to_dtype(dtype)
-    x_triton_dequant = nvfp4_triton.to_dtype(dtype)
+    x_pt_dequant = nvfp4_pt.dequantize(dtype)
+    x_triton_dequant = nvfp4_triton.dequantize(dtype)

     sqnr = compute_error(x_pt_dequant, x_triton_dequant)
     SQNR_THRESHOLD = 40.0

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 7 additions & 5 deletions
@@ -567,13 +567,15 @@ def __repr__(self):
     def _quantization_type(self):
         return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"

-    def to_dtype(self, target_dtype):
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
         return to_dtype(
             self.qdata,
             self.scale,
             self._elem_dtype,
             self._block_size,
-            target_dtype,
+            output_dtype,
             self._pack_fp6,
         )

@@ -718,8 +720,8 @@ def _addmm_mx_dispatch(

     else:
         # emulated MX gemm
-        a_hp = a.to_dtype(a._orig_dtype)
-        b_hp = b.to_dtype(b._orig_dtype)
+        a_hp = a.dequantize(a._orig_dtype)
+        b_hp = b.dequantize(b._orig_dtype)
         # assert memory layout we expect to be required in hardware
         assert a_hp.is_contiguous()
         assert b_hp.t().is_contiguous()

@@ -780,7 +782,7 @@ def mx_cast_up_op(func, types, args, kwargs):

     def unwrap(x):
         if isinstance(x, MXTensor):
-            return x.to_dtype(x._orig_dtype)
+            return x.dequantize(x._orig_dtype)
         return x

     new_args = tree_map(unwrap, args)
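Besides the rename, the MXTensor signature above makes the target dtype optional. A rough usage sketch (constructor arguments mirror the test diffs; it assumes self.dtype reports the original high-precision dtype, which is what the new `if output_dtype is None` branch relies on):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

data_hp = torch.randn(64, 64, dtype=torch.bfloat16)
data_mx = MXTensor.to_mx(data_hp, torch.float8_e4m3fn, 32)

# explicit output dtype, as the updated tests do
data_dq = data_mx.dequantize(torch.bfloat16)

# no argument: output_dtype falls back to self.dtype
data_dq_default = data_mx.dequantize()
assert data_dq_default.dtype == data_mx.dtype
```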

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 10 additions & 8 deletions
@@ -136,7 +136,7 @@ def __new__(
         return self

     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.dequantize(self._orig_dtype)}"

     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"

@@ -217,7 +217,7 @@ def to_nvfp4(
     # Do not force the NVFP4Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl

-    def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Convert NVFP4Tensor back to high precision dtype.

         Args:

@@ -226,6 +226,8 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
+        if output_dtype is None:
+            output_dtype = self.dtype
         is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
             leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]

@@ -242,7 +244,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
             *leading_dims, M, K // self._block_size, 1
         )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(output_dtype)

         if is_transposed:
             result = result.transpose(-2, -1)

@@ -731,7 +733,7 @@ def nvfp4_linear(func, types, args, kwargs):

     if weight_tensor.act_quant_kwargs is None:
         # weight_only quant
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
         # dynamic quant

@@ -759,9 +761,9 @@ def nvfp4_mm(func, types, args, kwargs):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

     if weight_tensor.act_quant_kwargs is None:
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
-            input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
+            input_dequant = input_tensor.dequantize(input_tensor._orig_dtype)
             return func(input_dequant, weight_dequant)
         else:
             return func(input_tensor, weight_dequant)

@@ -791,9 +793,9 @@ def nvfp4_addmm(func, types, args, kwargs):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

     if weight_tensor.act_quant_kwargs is None:
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
-            input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
+            input_dequant = input_tensor.dequantize(input_tensor._orig_dtype)
             return torch.addmm(bias, input_dequant, weight_dequant)
         else:
             return torch.addmm(bias, input_tensor, weight_dequant)
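NVFP4Tensor gets the same treatment: dequantize replaces to_dtype and output_dtype defaults to self.dtype. A rough sketch mirroring the updated tests (the shape and omission of per_tensor_scale are illustrative; NVFP4 packs two FP4 values per byte, so the last dim should be divisible by the 16-element block size):

```python
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = torch.randn(128, 64, dtype=torch.bfloat16)
x_nvfp4 = NVFP4Tensor.to_nvfp4(x)

# old spelling: x_nvfp4.to_dtype(torch.bfloat16); new spelling:
x_reconstructed = x_nvfp4.dequantize(torch.bfloat16)
assert x_reconstructed.shape == x.shape

# omitting the dtype now falls back to self.dtype
x_default = x_nvfp4.dequantize()
```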
