From 40de7e0d1c1387d1fc39f17ea4716eeb603bf087 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 01:11:31 +0000 Subject: [PATCH 01/15] Move dyn_int8_act_int4_wei_cpu_layout to prototype --- torchao/prototype/dtypes/uintx/__init__.py | 2 + .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 319 ++++++++++++++++++ 2 files changed, 321 insertions(+) create mode 100644 torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 53edddb8ac..89c1f3f810 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -6,8 +6,10 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout +from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py new file mode 100644 index 0000000000..8d0cfaddeb --- /dev/null +++ b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import torch_version_at_least + +from .int4_cpu_layout import ( + Int4CPUAQTTensorImpl, + _is_float, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Int8DynamicActInt4WeightCPULayout(Layout): + """Layout class for da8w4 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Int8DynamicActInt4WeightCPULayout) +class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): + """TensorImpl for da8w4 CPU layout for affine quantized tensor + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor + qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.qzeros = qzeros + self.compensation = compensation + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales", "qzeros", "compensation"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales, qzeros, compensation = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + tensor_data_dict["qzeros"], + tensor_data_dict["compensation"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) + assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" + assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + if zero_point.dim() == 1: + zero_point.unsqueeze_(-1) + + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. + # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. + # Compensation shape = [N / block_n, K / block_k, block_n]. + weight_int4, scales, qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) + ) + return cls(weight_int4, scales, qzeros, compensation, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + fn(self.qzeros), + fn(self.compensation), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = DA8W4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scales, + args[0].qzeros, + args[0].compensation, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + else: + return super().__torch_dispatch__(func, types, args, kwargs) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] * 2 + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.uint8) + x_scale = torch.ones(K).float() + x_qzero = torch.zeros(K).to(torch.int32) + w_scale = torch.ones_like(self.scales).float() + w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) + plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( + x, + x_scale, + x_qzero, + self.packed_weight, + w_scale, + w_qzero, + self.compensation, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = plain_weight.to(torch.int8) + + if self.scales.dim() == 2: + assert self.qzeros.dim() == 2 + plain_scales = self.scales + plain_qzeros = self.qzeros + else: + assert self.scales.dim() == 3 and self.qzeros.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + plain_qzeros = ( + self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, plain_qzeros + + +def _aqt_is_uint8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 255 + ) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and aqt.quant_max == 127 + ) + + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): + return ( + torch_version_at_least("2.7.0") + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and isinstance(input_tensor, AffineQuantizedTensor) + and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) + and _is_float(input_tensor.dtype) + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and _is_float(weight_tensor.dtype) + and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) + ) + + +def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): + assert torch_version_at_least("2.7.0"), ( + f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" + ) + if _aqt_is_int8(input_tensor): + assert torch_version_at_least("2.8.0"), ( + f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" + ) + assert is_device(input_tensor.device.type, "cpu"), ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + act = act_mat.tensor_impl.int_data + act_scales = act_mat.tensor_impl.scale + act_qzeros = act_mat.tensor_impl.zero_point + + packed_weight = weight_tensor.tensor_impl.packed_weight + wei_scales = weight_tensor.tensor_impl.scales + wei_qzeros = weight_tensor.tensor_impl.qzeros + compensation = weight_tensor.tensor_impl.compensation + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act = act.reshape(-1, act.shape[-1]) + + y = torch.ops.torchao.da8w4_linear_cpu.default( + act.contiguous(), + act_scales, + act_qzeros, + packed_weight, + wei_scales, + wei_qzeros, + compensation, + bias.float() if bias is not None else bias, # requires bias to be float + orig_dtype, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) + + +# Register the concat linear fusion pass +# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass + +# register_da8w4_concat_linear_cpu_pass() From d3db93eca4d98f527a51153dde709c6805ae5f17 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 03:48:30 +0000 Subject: [PATCH 02/15] Move dyn_int8_act_int4_wei_cpu_layout to prototype --- test/quantization/test_da8w4_cpu.py | 27 ++ torchao/dtypes/__init__.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 326 +----------------- torchao/prototype/dtypes/__init__.py | 7 +- .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 7 +- 6 files changed, 58 insertions(+), 319 deletions(-) diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py index d4f68c4333..c4b0eac39f 100644 --- a/test/quantization/test_da8w4_cpu.py +++ b/test/quantization/test_da8w4_cpu.py @@ -176,5 +176,32 @@ def forward(self, x): common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) +# TODO: Remove this test once the deprecated API has been removed +def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): + import sys + import warnings + + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" + ) + + if __name__ == "__main__": run_tests() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 252498bc97..354692e794 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -16,7 +16,6 @@ from .uintx import ( Int4CPULayout, Int4XPULayout, - Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, MarlinSparseLayout, @@ -29,6 +28,7 @@ ) from .uintx.block_sparse_layout import BlockSparseLayout from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout +from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index e46809059e..3816f9bf1f 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( - _linear_int8_act_int4_weight_cpu_check, - _linear_int8_act_int4_weight_cpu_impl, -) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -94,6 +90,10 @@ _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) +from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index 8d0cfaddeb..d66f70e2ee 100644 --- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -3,317 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, +warnings.warn( + "Importing from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import Int8DynamicActInt4WeightCPULayout' instead. " + "This import path will be removed in a future release of torchao. " + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import Layout, PlainLayout, is_device -from torchao.utils import torch_version_at_least -from .int4_cpu_layout import ( - Int4CPUAQTTensorImpl, - _is_float, +from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( # noqa: F401 + DA8W4CPUAQTTensorImpl, # noqa: F401 + Int8DynamicActInt4WeightCPULayout, # noqa: F401 + _aqt_is_int8, # noqa: F401 + _aqt_is_uint4, # noqa: F401 + _aqt_is_uint8, # noqa: F401 + _linear_int8_act_int4_weight_cpu_check, # noqa: F401 + _linear_int8_act_int4_weight_cpu_impl, # noqa: F401 ) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class Int8DynamicActInt4WeightCPULayout(Layout): - """Layout class for da8w4 CPU layout for affine quantized tensor""" - - pass - - -@register_layout(Int8DynamicActInt4WeightCPULayout) -class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): - """TensorImpl for da8w4 CPU layout for affine quantized tensor - It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of - dimension: [n][k / 2] (uint8 dtype) - It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data - fields: - packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout - scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor - qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scales = scales - self.qzeros = qzeros - self.compensation = compensation - self.transposed = transposed - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scales", "qzeros", "compensation"], [ - self.transposed, - self._layout, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scales, qzeros, compensation = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scales"], - tensor_data_dict["qzeros"], - tensor_data_dict["compensation"], - ) - ( - transposed, - _layout, - ) = tensor_attributes - return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) - assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" - assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" - if scale.dim() == 1: - scale.unsqueeze_(-1) - scale = scale.to(torch.float) - if zero_point.dim() == 1: - zero_point.unsqueeze_(-1) - - # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. - # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. - # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. - # Compensation shape = [N / block_n, K / block_k, block_n]. - weight_int4, scales, qzeros, compensation = ( - torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) - ) - return cls(weight_int4, scales, qzeros, compensation, False, _layout) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scales), - fn(self.qzeros), - fn(self.compensation), - self.transposed, - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = DA8W4CPUAQTTensorImpl( - args[0].packed_weight, - args[0].scales, - args[0].qzeros, - args[0].compensation, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - else: - return super().__torch_dispatch__(func, types, args, kwargs) - - __torch_function__ = torch._C._disabled_torch_function_impl - - @property - def block_size(self): - assert len(self.packed_weight.shape) == 2 - weight_shape = self.packed_weight.shape - N = weight_shape[0] - K = weight_shape[1] * 2 - groups = self.scales.numel() // N - group_size = K // groups - return (1, group_size) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Unpack weight by linear(eye(K), packed_weight).t() - packed_w_shape = self.packed_weight.shape - if len(packed_w_shape) == 4: - K = packed_w_shape[1] * packed_w_shape[2] - else: - K = packed_w_shape[1] - x = torch.eye(K).to(torch.uint8) - x_scale = torch.ones(K).float() - x_qzero = torch.zeros(K).to(torch.int32) - w_scale = torch.ones_like(self.scales).float() - w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) - plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( - x, - x_scale, - x_qzero, - self.packed_weight, - w_scale, - w_qzero, - self.compensation, - None, # bias - torch.float, # out_dtype - ) - plain_weight = plain_weight.t().contiguous() - plain_weight = plain_weight.to(torch.int8) - - if self.scales.dim() == 2: - assert self.qzeros.dim() == 2 - plain_scales = self.scales - plain_qzeros = self.qzeros - else: - assert self.scales.dim() == 3 and self.qzeros.dim() == 3 - packed_shape = self.scales.shape # [Nc, G, block_n] - plain_scales = ( - self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - plain_qzeros = ( - self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - - return plain_weight, plain_scales, plain_qzeros - - -def _aqt_is_uint8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 255 - ) - - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 - and aqt.quant_min == -127 - and aqt.quant_max == 127 - ) - - -def _aqt_is_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 15 - ) - - -def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): - return ( - torch_version_at_least("2.7.0") - and is_device(input_tensor.device.type, "cpu") - and is_device(weight_tensor.device.type, "cpu") - and (bias is None or is_device(bias.device.type, "cpu")) - and isinstance(input_tensor, AffineQuantizedTensor) - and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) - and _is_float(input_tensor.dtype) - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_uint4(weight_tensor) - and _is_float(weight_tensor.dtype) - and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) - ) - - -def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert torch_version_at_least("2.7.0"), ( - f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" - ) - if _aqt_is_int8(input_tensor): - assert torch_version_at_least("2.8.0"), ( - f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" - ) - assert is_device(input_tensor.device.type, "cpu"), ( - f"For CPU device only but got: {input_tensor.device}" - ) - assert weight_tensor.block_size[0] == 1, ( - f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - act_mat = input_tensor - act = act_mat.tensor_impl.int_data - act_scales = act_mat.tensor_impl.scale - act_qzeros = act_mat.tensor_impl.zero_point - - packed_weight = weight_tensor.tensor_impl.packed_weight - wei_scales = weight_tensor.tensor_impl.scales - wei_qzeros = weight_tensor.tensor_impl.qzeros - compensation = weight_tensor.tensor_impl.compensation - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape to 2D - act = act.reshape(-1, act.shape[-1]) - - y = torch.ops.torchao.da8w4_linear_cpu.default( - act.contiguous(), - act_scales, - act_qzeros, - packed_weight, - wei_scales, - wei_qzeros, - compensation, - bias.float() if bias is not None else bias, # requires bias to be float - orig_dtype, # out_dtype - ) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - return y.to(orig_dtype) - - -# Register the concat linear fusion pass -# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass - -# register_da8w4_concat_linear_cpu_pass() diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 25f139d583..52a5aec425 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -4,9 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from .uintx import BlockSparseLayout, CutlassInt4PackedLayout +from .uintx import ( + BlockSparseLayout, + CutlassInt4PackedLayout, + Int8DynamicActInt4WeightCPULayout, +) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index 8d0cfaddeb..24cc02e358 100644 --- a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -15,13 +15,12 @@ AffineQuantizedTensor, register_layout, ) -from torchao.dtypes.utils import Layout, PlainLayout, is_device -from torchao.utils import torch_version_at_least - -from .int4_cpu_layout import ( +from torchao.dtypes.uintx.int4_cpu_layout import ( Int4CPUAQTTensorImpl, _is_float, ) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import torch_version_at_least aten = torch.ops.aten From 6d87fb3b39838d4c00efae981e1c0c35795a9d0d Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 22:33:23 +0000 Subject: [PATCH 03/15] Move marlin_qqq_tensor to prototype --- benchmarks/microbenchmarks/utils.py | 2 +- test/quantization/test_marlin_qqq.py | 22 +- torchao/_models/llama/generate.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- torchao/dtypes/uintx/__init__.py | 11 +- torchao/dtypes/uintx/marlin_qqq_tensor.py | 359 +----------------- torchao/prototype/dtypes/__init__.py | 6 + torchao/prototype/dtypes/uintx/__init__.py | 8 + .../dtypes/uintx/marlin_qqq_tensor.py | 351 +++++++++++++++++ 9 files changed, 415 insertions(+), 354 deletions(-) create mode 100644 torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index d7300a6a81..2c6a443a86 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -218,7 +218,7 @@ def string_to_config( ) if "marlin" in quantization: if "qqq" in quantization: - from torchao.dtypes import MarlinQQQLayout + from torchao.prototype.dtypes import MarlinQQQLayout return Int8DynamicActivationInt4WeightConfig( group_size=128, diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index e0733520ff..ec52a71545 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -10,7 +10,7 @@ from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests -from torchao.dtypes import MarlinQQQLayout +from torchao.prototype.dtypes import MarlinQQQLayout from torchao.quantization.marlin_qqq import ( pack_to_marlin_qqq, unpack_from_marlin_qqq, @@ -132,5 +132,25 @@ def test_pack_unpack_equivalence(self): ) +def test_marlin_qqq_tensor_deprecation_warning(): + """Test that importing from the old location raises a deprecation warning""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + # Import from the old deprecated location + from torchao.dtypes.uintx.marlin_qqq_tensor import ( # noqa: F401 + MarlinQQQLayout, + ) + + # Verify the deprecation warning was raised + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "torchao.dtypes.uintx.marlin_qqq_tensor is deprecated" in str( + w[-1].message + ) + assert "torchao.prototype.dtypes import" in str(w[-1].message) + + if __name__ == "__main__": run_tests() diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index da1b848bcb..fc3d371139 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -460,7 +460,7 @@ def ffn_or_attn_only(mod, fqn): ) if "marlin" in quantization: if "qqq" in quantization: - from torchao.dtypes import MarlinQQQLayout + from torchao.prototype.dtypes import MarlinQQQLayout quantize_( model, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 3816f9bf1f..21f13729dd 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -39,10 +39,6 @@ _linear_fp_act_uint4_weight_int8_zero_check, _linear_fp_act_uint4_weight_int8_zero_impl, ) -from torchao.dtypes.uintx.marlin_qqq_tensor import ( - _linear_int8_act_int4_weight_marlin_qqq_check, - _linear_int8_act_int4_weight_marlin_qqq_impl, -) from torchao.dtypes.uintx.marlin_sparse_layout import ( _linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl, @@ -94,6 +90,10 @@ _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ) +from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( + _linear_int8_act_int4_weight_marlin_qqq_check, + _linear_int8_act_int4_weight_marlin_qqq_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index b76e80e0fc..71106d809d 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,3 +1,9 @@ +from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, +) + from .dyn_int8_act_int4_wei_cpu_layout import ( Int8DynamicActInt4WeightCPULayout, ) @@ -7,11 +13,6 @@ from .int4_xpu_layout import ( Int4XPULayout, ) -from .marlin_qqq_tensor import ( - MarlinQQQLayout, - MarlinQQQTensor, - to_marlinqqq_quantized_intx, -) from .marlin_sparse_layout import ( MarlinSparseLayout, ) diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 04066a6c65..19d16a1e9f 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -3,349 +3,24 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import logging -import math -from dataclasses import dataclass -from typing import Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - get_tensor_impl_constructor, - register_layout, -) -from torchao.dtypes.uintx.plain_layout import ( - _aqt_is_int8_reduced_range, +warnings.warn( + "Importing from torchao.dtypes.uintx.marlin_qqq_tensor is deprecated. " + "Please use 'from torchao.prototype.dtypes import MarlinQQQLayout, MarlinQQQTensor' instead. " + "This import path will be removed in a future release of torchao. " + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout -from torchao.quantization.quant_primitives import ( - ZeroPointDomain, - _choose_qparams_and_quantize_affine_qqq, - _dequantize_affine_qqq, -) - -logger = logging.getLogger(__name__) - -aten = torch.ops.aten - - -class MarlinQQQTensor(AffineQuantizedTensor): - """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. - - To see what happens during _choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, - please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: _choose_qparams_and_quantize_affine_qqq and _dequantize_affine_qqq - """ - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - - int_data, s_group, s_channel = self.tensor_impl.get_plain() - nbits = int(math.log2(self.quant_max - self.quant_min + 1)) - group_size = max(self.block_size) - return _dequantize_affine_qqq( - int_data, s_group, s_channel, nbits, group_size, output_dtype - ) - - @classmethod - def from_hp_to_intx( - cls, - input_float: torch.Tensor, - block_size: Tuple[int, ...], - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - _layout: Optional[Layout] = None, - ): - """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" - if zero_point_domain is None: - raise ValueError("Please use ZeroPointDomain.NONE instead of None") - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - nbits = int(math.log2(quant_max - quant_min + 1)) - group_size = max(block_size) - data, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( - input_float, nbits, group_size - ) - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - quant_min, - quant_max, - zero_point_domain, - dtype=input_float.dtype, - ) - - -@dataclass(frozen=True) -class MarlinQQQLayout(Layout): - """MarlinQQQLayout is a layout class for Marlin QQQ quantization.""" - - pass - - -@register_layout(MarlinQQQLayout) -class MarlinQQQAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl storage class for sparse_qqq layout for affine quantized tensor. - - Can only be used with 4 bits quantization for now. - - Original marlin documentation and information: - https://github.com/IST-DASLab/marlin/tree/master - - Marlin qqq information: - https://github.com/HandH1998/QQQ/tree/main - https://arxiv.org/pdf/2406.09904 - - fields: - original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape - group_size (int): the group size used to pack the tensor - num_bits (int): the number of bits used to quantize the tensor - """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - self.int_data = int_data - self.s_group = s_group - self.s_channel = s_channel - self._layout = _layout - self.original_shape = original_shape - self.group_size = group_size - self.num_bits = num_bits - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "s_group", "s_channel"], [ - self._layout, - self.original_shape, - self.group_size, - self.num_bits, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - s_group = tensor_data_dict["s_group"] - s_channel = tensor_data_dict["s_channel"] - _layout, original_shape, group_size, num_bits = tensor_attributes - return cls( - int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits - ) - - def get_plain(self): - from torchao.quantization.marlin_qqq import ( - unpack_from_marlin_qqq, - ) - int_data_expanded, s_group_expanded, s_channel_expanded = ( - unpack_from_marlin_qqq( - self.int_data, - self.s_group, - self.s_channel, - self.original_shape, - self.num_bits, - self.group_size, - ) - ) - int_data_expanded_t = int_data_expanded.t() - s_group_expanded_t = s_group_expanded.t() - s_channel_expanded_t = s_channel_expanded.t() - return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - ): - from torchao.quantization.marlin_qqq import ( - const, - pack_to_marlin_qqq, - ) - - assert isinstance(_layout, MarlinQQQLayout) - - # Linear layers are (in_features, out_features) but the int_data that is reaching this point - # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. - q_w = int_data.t() - s_group_t = s_group.t() - s_channel_t = s_channel.t() - - if not torch.cuda.get_device_capability()[0] >= 8: - raise ValueError( - f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." - ) - - if q_w.dtype != torch.int32: - raise ValueError("Only `torch.int32` weights are supported.") - - in_features, out_features = q_w.shape - # (thread_k, thread_n) - thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] - if not any( - [ - in_features % thread_k == 0 and out_features % thread_n == 0 - for thread_k, thread_n in thread_config - ] - ): - raise ValueError( - "Not supported `in_features`: {} and `out_features`: {}.".format( - in_features, out_features - ) - ) - - num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 - if num_bits not in [4]: - raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") - - if s_group.numel() == 0: - group_size = -1 - else: - group_size = in_features // s_group_t.shape[0] - assert group_size <= in_features, ( - "Group size must be less than or equal to in_features." - ) - - if group_size not in const.SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." - ) - - # Compress quantized weight to marlin format - marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( - q_w, s_group_t, s_channel_t, num_bits, group_size - ) - - return cls( - marlin_qqq_q_w, - marlin_qqq_s_group, - marlin_qqq_s_channel, - _layout, - q_w.shape, - group_size, - num_bits, - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.s_group = fn(self.s_group) - self.s_channel = fn(self.s_channel) - return self - - -def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and input_tensor.dtype == torch.float16 - and input_tensor.tensor_impl.scale.dtype == torch.float32 - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.tensor_impl.dtype == torch.int32 - and len(weight_tensor.shape) == 2 - and isinstance(weight_tensor._layout, MarlinQQQLayout) - ) - - -def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): - from torchao.ops import marlin_qqq_gemm - from torchao.quantization.marlin_qqq import marlin_qqq_workspace - - assert isinstance(input_tensor, AffineQuantizedTensor) - assert isinstance(weight_tensor, AffineQuantizedTensor) - - input = input_tensor.tensor_impl.int_data - input_scale = input_tensor.tensor_impl.scale - - w_int4 = weight_tensor.tensor_impl.int_data - s_group = weight_tensor.tensor_impl.s_group - s_channel = weight_tensor.tensor_impl.s_channel - original_shape = weight_tensor.tensor_impl.original_shape - - # Folds batch dimension into the first dimension - input_2d = input.view(-1, input.shape[-1]) - input_scale = input_scale.view(1, -1) - - size_m = input_2d.shape[0] - size_n = s_channel.shape[1] - size_k = input_2d.shape[1] - workspace_qqq = marlin_qqq_workspace(original_shape[1]) - - out = marlin_qqq_gemm( - input_2d, - w_int4, - input_scale, - s_channel, - s_group, - workspace_qqq, - size_m, - size_n, - size_k, - ) - - # Unfold the batch dimension - out = out.reshape(input.shape[:-1] + (s_channel.shape[1],)) - - if bias is not None: - out += bias.to(out.dtype) - return out - - -to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx +from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( # noqa: F401 + MarlinQQQAQTTensorImpl, # noqa: F401 + MarlinQQQLayout, # noqa: F401 + MarlinQQQTensor, # noqa: F401 + _linear_int8_act_int4_weight_marlin_qqq_check, # noqa: F401 + _linear_int8_act_int4_weight_marlin_qqq_impl, # noqa: F401 + to_marlinqqq_quantized_intx, # noqa: F401 +) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 52a5aec425..294c7d0b15 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -8,10 +8,16 @@ BlockSparseLayout, CutlassInt4PackedLayout, Int8DynamicActInt4WeightCPULayout, + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, ) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", "Int8DynamicActInt4WeightCPULayout", + "MarlinQQQLayout", + "MarlinQQQTensor", + "to_marlinqqq_quantized_intx", ] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 89c1f3f810..cd333a90e9 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -7,9 +7,17 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout +from .marlin_qqq_tensor import ( + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, +) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", "Int8DynamicActInt4WeightCPULayout", + "MarlinQQQLayout", + "MarlinQQQTensor", + "to_marlinqqq_quantized_intx", ] diff --git a/torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py b/torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py new file mode 100644 index 0000000000..04066a6c65 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + get_tensor_impl_constructor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + _choose_qparams_and_quantize_affine_qqq, + _dequantize_affine_qqq, +) + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +class MarlinQQQTensor(AffineQuantizedTensor): + """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. + + To see what happens during _choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: _choose_qparams_and_quantize_affine_qqq and _dequantize_affine_qqq + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + int_data, s_group, s_channel = self.tensor_impl.get_plain() + nbits = int(math.log2(self.quant_max - self.quant_min + 1)) + group_size = max(self.block_size) + return _dequantize_affine_qqq( + int_data, s_group, s_channel, nbits, group_size, output_dtype + ) + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + block_size: Tuple[int, ...], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + _layout: Optional[Layout] = None, + ): + """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + nbits = int(math.log2(quant_max - quant_min + 1)) + group_size = max(block_size) + data, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( + input_float, nbits, group_size + ) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + + +@dataclass(frozen=True) +class MarlinQQQLayout(Layout): + """MarlinQQQLayout is a layout class for Marlin QQQ quantization.""" + + pass + + +@register_layout(MarlinQQQLayout) +class MarlinQQQAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl storage class for sparse_qqq layout for affine quantized tensor. + + Can only be used with 4 bits quantization for now. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Marlin qqq information: + https://github.com/HandH1998/QQQ/tree/main + https://arxiv.org/pdf/2406.09904 + + fields: + original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.s_group = s_group + self.s_channel = s_channel + self._layout = _layout + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "s_group", "s_channel"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + s_group = tensor_data_dict["s_group"] + s_channel = tensor_data_dict["s_channel"] + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls( + int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits + ) + + def get_plain(self): + from torchao.quantization.marlin_qqq import ( + unpack_from_marlin_qqq, + ) + + int_data_expanded, s_group_expanded, s_channel_expanded = ( + unpack_from_marlin_qqq( + self.int_data, + self.s_group, + self.s_channel, + self.original_shape, + self.num_bits, + self.group_size, + ) + ) + int_data_expanded_t = int_data_expanded.t() + s_group_expanded_t = s_group_expanded.t() + s_channel_expanded_t = s_channel_expanded.t() + return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + ): + from torchao.quantization.marlin_qqq import ( + const, + pack_to_marlin_qqq, + ) + + assert isinstance(_layout, MarlinQQQLayout) + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w = int_data.t() + s_group_t = s_group.t() + s_channel_t = s_channel.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + ) + + if q_w.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w.shape + # (thread_k, thread_n) + thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] + if not any( + [ + in_features % thread_k == 0 and out_features % thread_n == 0 + for thread_k, thread_n in thread_config + ] + ): + raise ValueError( + "Not supported `in_features`: {} and `out_features`: {}.".format( + in_features, out_features + ) + ) + + num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + if s_group.numel() == 0: + group_size = -1 + else: + group_size = in_features // s_group_t.shape[0] + assert group_size <= in_features, ( + "Group size must be less than or equal to in_features." + ) + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin format + marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( + q_w, s_group_t, s_channel_t, num_bits, group_size + ) + + return cls( + marlin_qqq_q_w, + marlin_qqq_s_group, + marlin_qqq_s_channel, + _layout, + q_w.shape, + group_size, + num_bits, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.s_group = fn(self.s_group) + self.s_channel = fn(self.s_channel) + return self + + +def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and input_tensor.dtype == torch.float16 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.tensor_impl.dtype == torch.int32 + and len(weight_tensor.shape) == 2 + and isinstance(weight_tensor._layout, MarlinQQQLayout) + ) + + +def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): + from torchao.ops import marlin_qqq_gemm + from torchao.quantization.marlin_qqq import marlin_qqq_workspace + + assert isinstance(input_tensor, AffineQuantizedTensor) + assert isinstance(weight_tensor, AffineQuantizedTensor) + + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + w_int4 = weight_tensor.tensor_impl.int_data + s_group = weight_tensor.tensor_impl.s_group + s_channel = weight_tensor.tensor_impl.s_channel + original_shape = weight_tensor.tensor_impl.original_shape + + # Folds batch dimension into the first dimension + input_2d = input.view(-1, input.shape[-1]) + input_scale = input_scale.view(1, -1) + + size_m = input_2d.shape[0] + size_n = s_channel.shape[1] + size_k = input_2d.shape[1] + workspace_qqq = marlin_qqq_workspace(original_shape[1]) + + out = marlin_qqq_gemm( + input_2d, + w_int4, + input_scale, + s_channel, + s_group, + workspace_qqq, + size_m, + size_n, + size_k, + ) + + # Unfold the batch dimension + out = out.reshape(input.shape[:-1] + (s_channel.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out + + +to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx From 3fb6a2c9b96053b47ade4fd19225437791d23380 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 05:54:45 +0000 Subject: [PATCH 04/15] Move all deprecated api tests into a single file --- test/dtypes/test_api_deprecation_warning.py | 85 +++++++++++++++++++++ test/integration/test_integration.py | 27 ------- test/quantization/test_da8w4_cpu.py | 27 ------- test/sparsity/test_sparse_api.py | 27 ------- 4 files changed, 85 insertions(+), 81 deletions(-) create mode 100644 test/dtypes/test_api_deprecation_warning.py diff --git a/test/dtypes/test_api_deprecation_warning.py b/test/dtypes/test_api_deprecation_warning.py new file mode 100644 index 0000000000..3c50e8aae3 --- /dev/null +++ b/test/dtypes/test_api_deprecation_warning.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for deprecated API imports that have been moved to prototype. +TODO: Remove these tests once the deprecated APIs have been removed. +""" + +import sys +import warnings + + +def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): + """Test deprecation warning for Int8DynamicActInt4WeightCPULayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" + ) + + +def test_cutlass_int4_packed_layout_deprecated(): + """Test deprecation warning for CutlassInt4PackedLayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.cutlass_int4_packed_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "CutlassInt4PackedLayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" + ) + + +def test_block_sparse_layout_deprecated(): + """Test deprecation warning for BlockSparseLayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.block_sparse_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import BlockSparseLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "BlockSparseLayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for BlockSparseLayout, got: {[str(warning.message) for warning in w]}" + ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 2d05426d73..dc58470526 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1946,32 +1946,5 @@ def test_benchmark_model_cpu(self): assert self.run_benchmark_model("cpu") is not None -# TODO: Remove this test once the deprecated API has been removed -def test_cutlass_int4_packed_layout_deprecated(): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.cutlass_int4_packed_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "CutlassInt4PackedLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" - ) - - if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py index c4b0eac39f..d4f68c4333 100644 --- a/test/quantization/test_da8w4_cpu.py +++ b/test/quantization/test_da8w4_cpu.py @@ -176,32 +176,5 @@ def forward(self, x): common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) -# TODO: Remove this test once the deprecated API has been removed -def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" - ) - - if __name__ == "__main__": run_tests() diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index c9d41a98a9..66cd032a9a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -267,33 +267,6 @@ def test_sparse(self, compile): torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) - # TODO: Remove this test once the deprecated API has been removed - def test_sparse_deprecated(self): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.block_sparse_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import BlockSparseLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - self.assertTrue( - any( - issubclass(warning.category, DeprecationWarning) - and "BlockSparseLayout" in str(warning.message) - for warning in w - ), - f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}", - ) - common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) From 67bc40a18167619d2441c6e94198653610e7b7f1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 22:45:14 -0800 Subject: [PATCH 05/15] Add to docs --- docs/source/api_ref_dtypes.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index e347dfd2e3..5c73d275eb 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -52,6 +52,7 @@ Prototype BlockSparseLayout CutlassInt4PackedLayout + Int8DynamicActInt4WeightCPULayout .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring From ae101f6fffb04f9ab3d359374824b81e5e7a09fe Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 23:03:40 -0800 Subject: [PATCH 06/15] Add deprecation test --- docs/source/api_ref_dtypes.rst | 4 ++-- test/dtypes/test_api_deprecation_warning.py | 24 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 5c73d275eb..58ad4ee8a4 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -23,8 +23,6 @@ Layouts and Tensor Subclasses FloatxTensorCoreLayout MarlinSparseLayout UintxLayout - MarlinQQQTensor - MarlinQQQLayout Int4CPULayout CutlassSemiSparseLayout @@ -53,6 +51,8 @@ Prototype BlockSparseLayout CutlassInt4PackedLayout Int8DynamicActInt4WeightCPULayout + MarlinQQQTensor + MarlinQQQLayout .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/test/dtypes/test_api_deprecation_warning.py b/test/dtypes/test_api_deprecation_warning.py index 3c50e8aae3..fc79bda4f1 100644 --- a/test/dtypes/test_api_deprecation_warning.py +++ b/test/dtypes/test_api_deprecation_warning.py @@ -83,3 +83,27 @@ def test_block_sparse_layout_deprecated(): ), ( f"Expected deprecation warning for BlockSparseLayout, got: {[str(warning.message) for warning in w]}" ) + + +def test_marlin_qqq_layout_deprecated(): + """Test deprecation warning for MarlinQQQLayout.""" + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.marlin_qqq_tensor", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import MarlinQQQLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + assert any( + issubclass(warning.category, DeprecationWarning) + and "MarlinQQQLayout" in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for MarlinQQQLayout, got: {[str(warning.message) for warning in w]}" + ) From 99bb707bdcf6885150a333b5bf674aae0a2e8681 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 23:16:19 -0800 Subject: [PATCH 07/15] Fixes --- test/dtypes/test_api_deprecation_warning.py | 2 +- torchao/dtypes/__init__.py | 8 +++++--- torchao/dtypes/uintx/__init__.py | 11 +++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/test/dtypes/test_api_deprecation_warning.py b/test/dtypes/test_api_deprecation_warning.py index fc79bda4f1..90d711c634 100644 --- a/test/dtypes/test_api_deprecation_warning.py +++ b/test/dtypes/test_api_deprecation_warning.py @@ -97,9 +97,9 @@ def test_marlin_qqq_layout_deprecated(): del sys.modules[mod] with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Ensure all warnings are captured from torchao.dtypes import MarlinQQQLayout # noqa: F401 - warnings.simplefilter("always") # Ensure all warnings are captured assert any( issubclass(warning.category, DeprecationWarning) and "MarlinQQQLayout" in str(warning.message) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 354692e794..4c83de7ddd 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -16,19 +16,21 @@ from .uintx import ( Int4CPULayout, Int4XPULayout, - MarlinQQQLayout, - MarlinQQQTensor, MarlinSparseLayout, PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, UintxLayout, - to_marlinqqq_quantized_intx, ) from .uintx.block_sparse_layout import BlockSparseLayout from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout +from .uintx.marlin_qqq_tensor import ( + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, +) from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 71106d809d..b76e80e0fc 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,9 +1,3 @@ -from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( - MarlinQQQLayout, - MarlinQQQTensor, - to_marlinqqq_quantized_intx, -) - from .dyn_int8_act_int4_wei_cpu_layout import ( Int8DynamicActInt4WeightCPULayout, ) @@ -13,6 +7,11 @@ from .int4_xpu_layout import ( Int4XPULayout, ) +from .marlin_qqq_tensor import ( + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, +) from .marlin_sparse_layout import ( MarlinSparseLayout, ) From f5d7e3a709e7a7156d9b31d705ae79ff4b5efea3 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 23:17:50 -0800 Subject: [PATCH 08/15] Empty commit to trigger CI From ecece7206661929694a860f2c77211fb8708d134 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 6 Nov 2025 23:18:25 -0800 Subject: [PATCH 09/15] Empty commit to trigger CI From ed58e1e2fa02e5abf63bcfbe92b8506a125d0652 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 18:30:31 +0000 Subject: [PATCH 10/15] Update tests --- test/dtypes/test_api_deprecation_warning.py | 85 --------------------- test/dtypes/test_uintx.py | 37 +++++++++ 2 files changed, 37 insertions(+), 85 deletions(-) delete mode 100644 test/dtypes/test_api_deprecation_warning.py diff --git a/test/dtypes/test_api_deprecation_warning.py b/test/dtypes/test_api_deprecation_warning.py deleted file mode 100644 index 3c50e8aae3..0000000000 --- a/test/dtypes/test_api_deprecation_warning.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -""" -Tests for deprecated API imports that have been moved to prototype. -TODO: Remove these tests once the deprecated APIs have been removed. -""" - -import sys -import warnings - - -def test_int8_dynamic_act_int4_weight_cpu_layout_deprecated(): - """Test deprecation warning for Int8DynamicActInt4WeightCPULayout.""" - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import Int8DynamicActInt4WeightCPULayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "Int8DynamicActInt4WeightCPULayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for Int8DynamicActInt4WeightCPULayout, got: {[str(warning.message) for warning in w]}" - ) - - -def test_cutlass_int4_packed_layout_deprecated(): - """Test deprecation warning for CutlassInt4PackedLayout.""" - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.cutlass_int4_packed_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "CutlassInt4PackedLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" - ) - - -def test_block_sparse_layout_deprecated(): - """Test deprecation warning for BlockSparseLayout.""" - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.block_sparse_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import BlockSparseLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "BlockSparseLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for BlockSparseLayout, got: {[str(warning.message) for warning in w]}" - ) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index cb0c88b21c..6be8b29400 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -3,6 +3,9 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import sys +import warnings + import pytest import torch @@ -165,3 +168,37 @@ def test_uintx_model_size(dtype): quantize_(linear[0], UIntXWeightOnlyConfig(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size + + +def test_uintx_api_deprecation(): + """ + Test that deprecated uintx APIs trigger deprecation warnings on import. + TODO: Remove this test once the deprecated APIs have been removed. + """ + deprecated_apis = [ + ( + "Int8DynamicActInt4WeightCPULayout", + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + ), + ("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"), + ("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"), + ] + + for api_name, module_path in deprecated_apis: + # Clear the cache to force re-importing and trigger the warning again + modules_to_clear = [module_path, "torchao.dtypes"] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Ensure all warnings are captured + + # Dynamically import the deprecated API + exec(f"from torchao.dtypes import {api_name}") + + assert any( + issubclass(warning.category, DeprecationWarning) + and api_name in str(warning.message) + for warning in w + ), f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" From 4f5ddd82d621ee9c600ef9bf0c235269b8762d1c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 18:33:26 +0000 Subject: [PATCH 11/15] lint fixes --- test/dtypes/test_uintx.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 6be8b29400..5d54a80753 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -201,4 +201,6 @@ def test_uintx_api_deprecation(): issubclass(warning.category, DeprecationWarning) and api_name in str(warning.message) for warning in w - ), f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" + ), ( + f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" + ) From 74511e20afdd3faf9f4d906fd813c74e7e1634d5 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 12:08:40 -0800 Subject: [PATCH 12/15] Update test --- test/dtypes/test_uintx.py | 1 + test/quantization/test_marlin_qqq.py | 20 -------------------- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 5d54a80753..0878dfed4d 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -182,6 +182,7 @@ def test_uintx_api_deprecation(): ), ("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"), ("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"), + ("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"), ] for api_name, module_path in deprecated_apis: diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index ec52a71545..6f0f0d69ba 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -132,25 +132,5 @@ def test_pack_unpack_equivalence(self): ) -def test_marlin_qqq_tensor_deprecation_warning(): - """Test that importing from the old location raises a deprecation warning""" - import warnings - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - # Import from the old deprecated location - from torchao.dtypes.uintx.marlin_qqq_tensor import ( # noqa: F401 - MarlinQQQLayout, - ) - - # Verify the deprecation warning was raised - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "torchao.dtypes.uintx.marlin_qqq_tensor is deprecated" in str( - w[-1].message - ) - assert "torchao.prototype.dtypes import" in str(w[-1].message) - - if __name__ == "__main__": run_tests() From 5e0f397a7b66848c6a1e67f78fa6a78b8e99395f Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 13:55:13 -0800 Subject: [PATCH 13/15] Move gemlite_layout.py to prototype/dtypes --- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- torchao/dtypes/uintx/gemlite_layout.py | 461 +----------------- .../prototype/dtypes/uintx/gemlite_layout.py | 452 +++++++++++++++++ torchao/quantization/autoquant.py | 2 +- torchao/quantization/quant_api.py | 2 +- 5 files changed, 476 insertions(+), 449 deletions(-) create mode 100644 torchao/prototype/dtypes/uintx/gemlite_layout.py diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 21f13729dd..6c7216ab12 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.gemlite_layout import ( - _linear_fp_act_int4_weight_gemlite_check, - _linear_fp_act_int4_weight_gemlite_impl, -) from torchao.dtypes.uintx.int4_cpu_layout import ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, @@ -90,6 +86,10 @@ _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ) +from torchao.prototype.dtypes.uintx.gemlite_layout import ( + _linear_fp_act_int4_weight_gemlite_check, + _linear_fp_act_int4_weight_gemlite_impl, +) from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 8a8f2309c9..c75c7fe1b1 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -3,450 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Dict, Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - is_traceable_wrapper_subclass, - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, +warnings.warn( + "Importing from torchao.dtypes.uintx.gemlite_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import GemlitePackedLayout' instead. " + "This import path will be removed in a future release of torchao. " + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl -from torchao.dtypes.utils import Layout -from torchao.utils import fill_defaults - -try: - import gemlite -except: - gemlite = None - -aten = torch.ops.aten - - -def _same_metadata( - self: "GemliteAQTTensorImpl", - src: "GemliteAQTTensorImpl", -) -> bool: - kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) - for k, v in self.gemlite_kwargs.items(): - if k in [ - "in_features", - "out_features", - "packing_bitwidth", - "elements_per_sample", - ]: - kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) - - return ( - isinstance(self, GemliteAQTTensorImpl) - and isinstance(src, GemliteAQTTensorImpl) - and self.shape == src.shape - and self.packed_weight.shape == src.packed_weight.shape - and self.scale.shape == src.scale.shape - and self.zero_point.shape == src.zero_point.shape - and kwargs_match - and type(self._layout) == type(src._layout) - ) - - -def get_gemlite_quant_kwargs(bit_width, group_size, dtype): - from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain - - kwargs = {} - if bit_width != 8: - kwargs["mapping_type"] = MappingType.ASYMMETRIC - kwargs["block_size"] = (1, group_size) - kwargs["target_dtype"] = torch.uint8 - kwargs["eps"] = 1e-6 - kwargs["quant_min"] = 0 - kwargs["quant_max"] = (2**bit_width) - 1 - kwargs["eps"] = 1e-6 - kwargs["zero_point_dtype"] = dtype - kwargs["zero_point_domain"] = ZeroPointDomain.FLOAT - elif bit_width == 8: - kwargs["mapping_type"] = MappingType.SYMMETRIC - kwargs["block_size"] = (1, group_size) - kwargs["target_dtype"] = torch.int8 - kwargs["quant_min"] = -128 - kwargs["quant_max"] = 127 - kwargs["eps"] = 1e-5 - kwargs["zero_point_dtype"] = None - kwargs["zero_point_domain"] = ZeroPointDomain.NONE - return kwargs - - -def get_gemlite_aqt_kwargs( - weight, - group_size=64, - bit_width=4, - packing_bitwidth=None, - mode="weight_only", - use_hqq=True, -): - if gemlite is None: - raise ImportError( - "Unable to import 'gemlite'. Please ensure it is installed correctly. You can install it with: pip install gemlite" - ) - - assert bit_width in [ - 4, - 8, - ], f"gemlite only works with bit_width 4,8 but got {bit_width}" - - assert weight.dtype in [torch.float16, torch.bfloat16], ( - f"gemlite only works with dtype torch.float16 or torch.bfloat16 but got {weight.dtype}" - ) - assert group_size in [32, 64, 128, 256, 512, 1024, None] - assert group_size is None or bit_width != 8, ( - "gemlite only works with group_size=None for bit_width=8" - ) - assert packing_bitwidth in [8, 16, 32, None], ( - f"Invalid packing bitwidth, got {packing_bitwidth}" - ) - - assert mode in ["weight_only", "dynamic"], ( - f"Invalid mode: should be either weight_only or dynamic, got {mode}" - ) - - out_features, in_features = weight.shape - group_size = in_features if group_size is None else group_size - - aqt_kwargs = get_gemlite_quant_kwargs(bit_width, group_size, weight.dtype) - aqt_kwargs["_layout"] = GemlitePackedLayout( - group_size=group_size, - bit_width=bit_width, - packing_bitwidth=packing_bitwidth, - mode=mode, - ) - aqt_kwargs["use_hqq"] = use_hqq - return aqt_kwargs - - -@dataclass(frozen=True) -class GemlitePackedLayout(Layout): - group_size: Optional[int] = 128 - bit_width: int = 4 - packing_bitwidth: Optional[int] = None - mode: Optional[str] = "weight_only" - - -@register_layout(GemlitePackedLayout) -class GemliteAQTTensorImpl(TensorCoreTiledAQTTensorImpl): - def __new__( - cls, - packed_weight: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - gemlite_kwargs: Dict, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - gemlite_kwargs: Dict, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scale = scale - self.zero_point = zero_point - self.gemlite_kwargs = gemlite_kwargs - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scale", "zero_point"], [ - self._layout, - self.gemlite_kwargs, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scale, zero_point = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scale"], - tensor_data_dict["zero_point"], - ) - _layout, gemlite_kwargs = tensor_attributes - return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, GemlitePackedLayout), ( - f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}" - ) - device = int_data.device - if device.type != "cuda": - int_data = ( - int_data.cuda() - ) # We need int_data on cuda device because of Triton packing - - group_size, bit_width = _layout.group_size, _layout.bit_width - out_features, in_features = int_data.shape - packing_bitwidth = _layout.packing_bitwidth - mode = _layout.mode - - if bit_width == 8 and group_size == in_features: - processor = ( - gemlite.helper.A8W8_int8_dynamic - if mode == "dynamic" - else gemlite.helper.A16W8 - ) - gemlite_linear = processor(device=int_data.device).from_weights( - int_data, scales=scale, bias=None - ) - else: - processor = ( - gemlite.helper.A8Wn_dynamic - if mode == "dynamic" - else gemlite.helper.A16Wn - ) - gemlite_linear = processor( - device=int_data.device, packing_bitwidth=packing_bitwidth - ).from_weights( - int_data, scale, zero_point, bit_width, group_size, bias=None - ) - - meta_args = gemlite_linear.get_meta_args() - gemlite_kwargs = { - "in_features": in_features, - "out_features": out_features, - "packing_bitwidth": packing_bitwidth, - "data_contiguous": gemlite_linear.data_contiguous, - "elements_per_sample": gemlite_linear.elements_per_sample, - "W_group_mode": gemlite_linear.W_group_mode, - "meta_args": meta_args, - } - - packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() - packed_weight = packed_weight.to(device) - if zero_point is None: - zero_point = torch.tensor( - [[]], device=packed_weight.device, dtype=torch.int32 - ) - - return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] - return self.__class__( - self.packed_weight.to(device), - self.scale.to(device), - self.zero_point.to(device), - self.gemlite_kwargs, - self._layout, - ) - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scale), - fn(self.zero_point), - self.gemlite_kwargs, - self._layout, - ) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device = self.packed_weight.device - int_data = ( - ( - gemlite.bitpack.unpack_over_rows( - self.packed_weight.cuda(), - W_nbits=self._layout.bit_width, - num_output_rows=self.gemlite_kwargs["in_features"], - dtype=torch.uint8, - ) - ) - .to(device) - .t() - ) - - # Preserve col-row major layout - if self.gemlite_kwargs["data_contiguous"]: - int_data = int_data.contiguous() - - # Handle FMA mode: W_q * s + z -> (W_q - z) * s - if self.gemlite_kwargs["W_group_mode"] == 4: - scale_min_val = 1e-8 - scale = self.scale.clone().float() - scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = ( - scale_min_val - ) - scale[ - torch.logical_and(scale < 0, scale.abs() <= scale_min_val) - ] = -scale_min_val - zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100) - zero_point = zero_point.to(self.scale.dtype) - else: - zero_point = self.zero_point - - scale = self.scale.t().contiguous() - zero_point = zero_point.t().contiguous() - - return int_data, scale, zero_point - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - # we don't handle transpose operations and just ignore them. In practice the only - # reason a transpsoe should occur is because the functional linear - # op can decompose into e.g. transpose + addmm so since we want - # to use the gemlite matmul kernel, which expects teh weight to be passed in as is, - # we ignore the transpose - if func is aten.detach.default or func is aten.t.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1, "Only step == 1 is supported in slicing right now" - - if dim in [0, 1]: - # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T - dim = 1 - dim - packed_weight = self.packed_weight - scale = self.scale - zero_point = self.zero_point - - gemlite_kwargs = self.gemlite_kwargs.copy() - orig_shape = [ - gemlite_kwargs["in_features"], - gemlite_kwargs["out_features"], - ] - elements_per_sample = gemlite_kwargs["elements_per_sample"] - data_len = orig_shape[dim] - scale_len = scale.shape[dim] - ratio = data_len / scale_len - start_scale = int(start / ratio) - end_scale = int(end / ratio) - - # For packing only the K dimension. This should be flipped for N-dim packing. - div = elements_per_sample if dim == 0 else 1 - packed_weight = aten.slice.Tensor( - packed_weight, dim, start // div, end // div, step - ) - - # Update in_features/out_features - gemlite_kwargs["in_features"] = ( - packed_weight.shape[0] * elements_per_sample - ) - gemlite_kwargs["out_features"] = packed_weight.shape[1] - - scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - if zero_point is not None and zero_point.numel() > 0: - zero_point = aten.slice.Tensor( - zero_point, dim, start_scale, end_scale, step - ) - else: - zero_point = None - - sliced = GemliteAQTTensorImpl( - packed_weight, scale, zero_point, gemlite_kwargs, self._layout - ) - return return_and_correct_aliasing(func, args, kwargs, sliced) - - else: - raise NotImplementedError( - f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - elif func is aten.copy_.default: - self = args[0] - src = args[1] - - # Handle zero_point = None with symmetric quant - if self.zero_point is None: - self.zero_point = torch.tensor( - [[]], device=self.packed_weight.device, dtype=torch.int32 - ) - - if src.zero_point is None: - src.zero_point = torch.tensor( - [[]], device=src.packed_weight.device, dtype=torch.int32 - ) - - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - for key in self.gemlite_kwargs: - self.gemlite_kwargs[key] = src.gemlite_kwargs[key] - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - raise NotImplementedError( - f"GemliteAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_layout(self) -> Layout: - return self._layout - - @property - def block_size(self): - return (1, self._layout.group_size) - - -def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None): - if hasattr(weight_tensor, "tensor_impl"): - weight_impl = weight_tensor.tensor_impl - else: - weight_impl = weight_tensor - - return gemlite.core.forward_functional( - x=input_tensor, - bias=bias, - tensor_args=( - weight_impl.packed_weight, - weight_impl.scale, - weight_impl.zero_point, - ), - meta_args=weight_impl.gemlite_kwargs["meta_args"], - ) - - -def _linear_fp_act_int4_weight_gemlite_check(input_tensor, weight_tensor, bias): - return ( - # input is native fp16 tensor - not is_traceable_wrapper_subclass(input_tensor) - # and input_tensor.dtype in [torch.float16, torch.bfloat16] - # weight is gemlite layout - and isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, GemlitePackedLayout) - ) +from torchao.prototype.dtypes.uintx.gemlite_layout import ( # noqa: F401 + GemliteAQTTensorImpl, # noqa: F401 + GemlitePackedLayout, # noqa: F401 + _linear_fp_act_int4_weight_gemlite_check, # noqa: F401 + _linear_fp_act_int4_weight_gemlite_impl, # noqa: F401 + _same_metadata, # noqa: F401 + get_gemlite_aqt_kwargs, # noqa: F401 + get_gemlite_quant_kwargs, # noqa: F401 +) diff --git a/torchao/prototype/dtypes/uintx/gemlite_layout.py b/torchao/prototype/dtypes/uintx/gemlite_layout.py new file mode 100644 index 0000000000..8a8f2309c9 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/gemlite_layout.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl +from torchao.dtypes.utils import Layout +from torchao.utils import fill_defaults + +try: + import gemlite +except: + gemlite = None + +aten = torch.ops.aten + + +def _same_metadata( + self: "GemliteAQTTensorImpl", + src: "GemliteAQTTensorImpl", +) -> bool: + kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) + for k, v in self.gemlite_kwargs.items(): + if k in [ + "in_features", + "out_features", + "packing_bitwidth", + "elements_per_sample", + ]: + kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) + + return ( + isinstance(self, GemliteAQTTensorImpl) + and isinstance(src, GemliteAQTTensorImpl) + and self.shape == src.shape + and self.packed_weight.shape == src.packed_weight.shape + and self.scale.shape == src.scale.shape + and self.zero_point.shape == src.zero_point.shape + and kwargs_match + and type(self._layout) == type(src._layout) + ) + + +def get_gemlite_quant_kwargs(bit_width, group_size, dtype): + from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain + + kwargs = {} + if bit_width != 8: + kwargs["mapping_type"] = MappingType.ASYMMETRIC + kwargs["block_size"] = (1, group_size) + kwargs["target_dtype"] = torch.uint8 + kwargs["eps"] = 1e-6 + kwargs["quant_min"] = 0 + kwargs["quant_max"] = (2**bit_width) - 1 + kwargs["eps"] = 1e-6 + kwargs["zero_point_dtype"] = dtype + kwargs["zero_point_domain"] = ZeroPointDomain.FLOAT + elif bit_width == 8: + kwargs["mapping_type"] = MappingType.SYMMETRIC + kwargs["block_size"] = (1, group_size) + kwargs["target_dtype"] = torch.int8 + kwargs["quant_min"] = -128 + kwargs["quant_max"] = 127 + kwargs["eps"] = 1e-5 + kwargs["zero_point_dtype"] = None + kwargs["zero_point_domain"] = ZeroPointDomain.NONE + return kwargs + + +def get_gemlite_aqt_kwargs( + weight, + group_size=64, + bit_width=4, + packing_bitwidth=None, + mode="weight_only", + use_hqq=True, +): + if gemlite is None: + raise ImportError( + "Unable to import 'gemlite'. Please ensure it is installed correctly. You can install it with: pip install gemlite" + ) + + assert bit_width in [ + 4, + 8, + ], f"gemlite only works with bit_width 4,8 but got {bit_width}" + + assert weight.dtype in [torch.float16, torch.bfloat16], ( + f"gemlite only works with dtype torch.float16 or torch.bfloat16 but got {weight.dtype}" + ) + assert group_size in [32, 64, 128, 256, 512, 1024, None] + assert group_size is None or bit_width != 8, ( + "gemlite only works with group_size=None for bit_width=8" + ) + assert packing_bitwidth in [8, 16, 32, None], ( + f"Invalid packing bitwidth, got {packing_bitwidth}" + ) + + assert mode in ["weight_only", "dynamic"], ( + f"Invalid mode: should be either weight_only or dynamic, got {mode}" + ) + + out_features, in_features = weight.shape + group_size = in_features if group_size is None else group_size + + aqt_kwargs = get_gemlite_quant_kwargs(bit_width, group_size, weight.dtype) + aqt_kwargs["_layout"] = GemlitePackedLayout( + group_size=group_size, + bit_width=bit_width, + packing_bitwidth=packing_bitwidth, + mode=mode, + ) + aqt_kwargs["use_hqq"] = use_hqq + return aqt_kwargs + + +@dataclass(frozen=True) +class GemlitePackedLayout(Layout): + group_size: Optional[int] = 128 + bit_width: int = 4 + packing_bitwidth: Optional[int] = None + mode: Optional[str] = "weight_only" + + +@register_layout(GemlitePackedLayout) +class GemliteAQTTensorImpl(TensorCoreTiledAQTTensorImpl): + def __new__( + cls, + packed_weight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + gemlite_kwargs: Dict, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + gemlite_kwargs: Dict, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale = scale + self.zero_point = zero_point + self.gemlite_kwargs = gemlite_kwargs + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale", "zero_point"], [ + self._layout, + self.gemlite_kwargs, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale, zero_point = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + _layout, gemlite_kwargs = tensor_attributes + return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, GemlitePackedLayout), ( + f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}" + ) + device = int_data.device + if device.type != "cuda": + int_data = ( + int_data.cuda() + ) # We need int_data on cuda device because of Triton packing + + group_size, bit_width = _layout.group_size, _layout.bit_width + out_features, in_features = int_data.shape + packing_bitwidth = _layout.packing_bitwidth + mode = _layout.mode + + if bit_width == 8 and group_size == in_features: + processor = ( + gemlite.helper.A8W8_int8_dynamic + if mode == "dynamic" + else gemlite.helper.A16W8 + ) + gemlite_linear = processor(device=int_data.device).from_weights( + int_data, scales=scale, bias=None + ) + else: + processor = ( + gemlite.helper.A8Wn_dynamic + if mode == "dynamic" + else gemlite.helper.A16Wn + ) + gemlite_linear = processor( + device=int_data.device, packing_bitwidth=packing_bitwidth + ).from_weights( + int_data, scale, zero_point, bit_width, group_size, bias=None + ) + + meta_args = gemlite_linear.get_meta_args() + gemlite_kwargs = { + "in_features": in_features, + "out_features": out_features, + "packing_bitwidth": packing_bitwidth, + "data_contiguous": gemlite_linear.data_contiguous, + "elements_per_sample": gemlite_linear.elements_per_sample, + "W_group_mode": gemlite_linear.W_group_mode, + "meta_args": meta_args, + } + + packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() + packed_weight = packed_weight.to(device) + if zero_point is None: + zero_point = torch.tensor( + [[]], device=packed_weight.device, dtype=torch.int32 + ) + + return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + return self.__class__( + self.packed_weight.to(device), + self.scale.to(device), + self.zero_point.to(device), + self.gemlite_kwargs, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale), + fn(self.zero_point), + self.gemlite_kwargs, + self._layout, + ) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = self.packed_weight.device + int_data = ( + ( + gemlite.bitpack.unpack_over_rows( + self.packed_weight.cuda(), + W_nbits=self._layout.bit_width, + num_output_rows=self.gemlite_kwargs["in_features"], + dtype=torch.uint8, + ) + ) + .to(device) + .t() + ) + + # Preserve col-row major layout + if self.gemlite_kwargs["data_contiguous"]: + int_data = int_data.contiguous() + + # Handle FMA mode: W_q * s + z -> (W_q - z) * s + if self.gemlite_kwargs["W_group_mode"] == 4: + scale_min_val = 1e-8 + scale = self.scale.clone().float() + scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = ( + scale_min_val + ) + scale[ + torch.logical_and(scale < 0, scale.abs() <= scale_min_val) + ] = -scale_min_val + zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100) + zero_point = zero_point.to(self.scale.dtype) + else: + zero_point = self.zero_point + + scale = self.scale.t().contiguous() + zero_point = zero_point.t().contiguous() + + return int_data, scale, zero_point + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + # we don't handle transpose operations and just ignore them. In practice the only + # reason a transpsoe should occur is because the functional linear + # op can decompose into e.g. transpose + addmm so since we want + # to use the gemlite matmul kernel, which expects teh weight to be passed in as is, + # we ignore the transpose + if func is aten.detach.default or func is aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1, "Only step == 1 is supported in slicing right now" + + if dim in [0, 1]: + # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T + dim = 1 - dim + packed_weight = self.packed_weight + scale = self.scale + zero_point = self.zero_point + + gemlite_kwargs = self.gemlite_kwargs.copy() + orig_shape = [ + gemlite_kwargs["in_features"], + gemlite_kwargs["out_features"], + ] + elements_per_sample = gemlite_kwargs["elements_per_sample"] + data_len = orig_shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + # For packing only the K dimension. This should be flipped for N-dim packing. + div = elements_per_sample if dim == 0 else 1 + packed_weight = aten.slice.Tensor( + packed_weight, dim, start // div, end // div, step + ) + + # Update in_features/out_features + gemlite_kwargs["in_features"] = ( + packed_weight.shape[0] * elements_per_sample + ) + gemlite_kwargs["out_features"] = packed_weight.shape[1] + + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + if zero_point is not None and zero_point.numel() > 0: + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + else: + zero_point = None + + sliced = GemliteAQTTensorImpl( + packed_weight, scale, zero_point, gemlite_kwargs, self._layout + ) + return return_and_correct_aliasing(func, args, kwargs, sliced) + + else: + raise NotImplementedError( + f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + elif func is aten.copy_.default: + self = args[0] + src = args[1] + + # Handle zero_point = None with symmetric quant + if self.zero_point is None: + self.zero_point = torch.tensor( + [[]], device=self.packed_weight.device, dtype=torch.int32 + ) + + if src.zero_point is None: + src.zero_point = torch.tensor( + [[]], device=src.packed_weight.device, dtype=torch.int32 + ) + + if _same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + for key in self.gemlite_kwargs: + self.gemlite_kwargs[key] = src.gemlite_kwargs[key] + return + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" + ) + + raise NotImplementedError( + f"GemliteAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_layout(self) -> Layout: + return self._layout + + @property + def block_size(self): + return (1, self._layout.group_size) + + +def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None): + if hasattr(weight_tensor, "tensor_impl"): + weight_impl = weight_tensor.tensor_impl + else: + weight_impl = weight_tensor + + return gemlite.core.forward_functional( + x=input_tensor, + bias=bias, + tensor_args=( + weight_impl.packed_weight, + weight_impl.scale, + weight_impl.zero_point, + ), + meta_args=weight_impl.gemlite_kwargs["meta_args"], + ) + + +def _linear_fp_act_int4_weight_gemlite_check(input_tensor, weight_tensor, bias): + return ( + # input is native fp16 tensor + not is_traceable_wrapper_subclass(input_tensor) + # and input_tensor.dtype in [torch.float16, torch.bfloat16] + # weight is gemlite layout + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, GemlitePackedLayout) + ) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index c72e18a923..884c96559a 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -724,7 +724,7 @@ class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight( @classmethod def from_float(cls, weight): from torchao.dtypes import to_affine_quantized_intx - from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + from torchao.prototype.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs if weight.dtype != torch.float16: weight = weight.to(torch.float16) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7054e00564..9da0e612b9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1058,7 +1058,7 @@ def _gemlite_uintx_weight_only_transform( weight = module.weight - from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + from torchao.prototype.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs use_hqq = True if bit_width == 4 else False new_weight = to_affine_quantized_intx( From 547e78575703286fe8ef5de87506a6304128d202 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 14:44:21 -0800 Subject: [PATCH 14/15] updates --- torchao/prototype/dtypes/__init__.py | 2 ++ torchao/prototype/dtypes/uintx/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 294c7d0b15..6033cdb8b8 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -11,6 +11,7 @@ MarlinQQQLayout, MarlinQQQTensor, to_marlinqqq_quantized_intx, + GemlitePackedLayout, ) __all__ = [ @@ -20,4 +21,5 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", + "GemlitePackedLayout", ] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index cd333a90e9..f21be2e072 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -12,6 +12,7 @@ MarlinQQQTensor, to_marlinqqq_quantized_intx, ) +from .gemlite_layout import GemlitePackedLayout __all__ = [ "BlockSparseLayout", @@ -20,4 +21,5 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", + "GemlitePackedLayout", ] From 625f8da29771d770fd7b19f1ec4ee7d34eac23d5 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 10 Nov 2025 11:11:37 -0800 Subject: [PATCH 15/15] lint fixes --- torchao/prototype/dtypes/__init__.py | 2 +- torchao/prototype/dtypes/uintx/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 6033cdb8b8..7ad78dbed6 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -7,11 +7,11 @@ from .uintx import ( BlockSparseLayout, CutlassInt4PackedLayout, + GemlitePackedLayout, Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, to_marlinqqq_quantized_intx, - GemlitePackedLayout, ) __all__ = [ diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index f21be2e072..56b1eed50a 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -7,12 +7,12 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout +from .gemlite_layout import GemlitePackedLayout from .marlin_qqq_tensor import ( MarlinQQQLayout, MarlinQQQTensor, to_marlinqqq_quantized_intx, ) -from .gemlite_layout import GemlitePackedLayout __all__ = [ "BlockSparseLayout",