Merged

29 commits
40de7e0
Move dyn_int8_act_int4_wei_cpu_layout to prototype
jainapurva Nov 6, 2025
d3db93e
Move dyn_int8_act_int4_wei_cpu_layout to prototype
jainapurva Nov 6, 2025
6d87fb3
Move marlin_qqq_tensor to prototype
jainapurva Nov 6, 2025
3fb6a2c
Move all deprecated api tests into a single file
jainapurva Nov 7, 2025
bcad575
Merge remote-tracking branch 'origin/move_dyn_int8' into move_marlin_qqq
jainapurva Nov 7, 2025
67bc40a
Add to docs
jainapurva Nov 7, 2025
2fd4321
Merge remote-tracking branch 'origin/move_dyn_int8' into move_marlin_qqq
jainapurva Nov 7, 2025
ae101f6
Add deprecation test
jainapurva Nov 7, 2025
99bb707
Fixes
jainapurva Nov 7, 2025
f5d7e3a
Empty commit to trigger CI
jainapurva Nov 7, 2025
ecece72
Empty commit to trigger CI
jainapurva Nov 7, 2025
4706b67
Merge remote-tracking branch 'origin/main' into move_dyn_int8
jainapurva Nov 7, 2025
ed58e1e
Update tests
jainapurva Nov 7, 2025
4f5ddd8
lint fixes
jainapurva Nov 7, 2025
257e94a
Merge remote-tracking branch 'origin/move_dyn_int8' into move_marlin_qqq
jainapurva Nov 7, 2025
74511e2
Update test
jainapurva Nov 7, 2025
5e0f397
Move gemlite_layout.py to prototype/dtypes
jainapurva Nov 7, 2025
547e785
updates
jainapurva Nov 7, 2025
88e54be
Merge branch 'main' into move_gemlite_layout
jainapurva Nov 8, 2025
71f039b
Move uintx_layout to prototype/dtypes
jainapurva Nov 8, 2025
625f8da
lint fixes
jainapurva Nov 10, 2025
8861a2e
Merge remote-tracking branch 'origin/move_gemlite_layout' into move_u…
jainapurva Nov 10, 2025
ea40463
Merge remote-tracking branch 'origin/main' into move_gemlite_layout
jainapurva Nov 11, 2025
0c876bb
Merge remote-tracking branch 'origin/main' into move_uintx_layout
jainapurva Nov 11, 2025
1b0cca5
Merge remote-tracking branch 'origin/move_gemlite_layout' into move_u…
jainapurva Nov 11, 2025
5a2bde6
Merge remote-tracking branch 'origin/main' into move_gemlite_layout
jainapurva Nov 12, 2025
57f57f6
Merge remote-tracking branch 'origin/move_gemlite_layout' into move_u…
jainapurva Nov 12, 2025
6889932
<Replace this line with a title. Use 1 line only, 67 chars or less>
jainapurva Nov 12, 2025
8576cbd
minor fixes
jainapurva Nov 12, 2025
2 changes: 1 addition & 1 deletion docs/source/api_ref_dtypes.rst
@@ -22,7 +22,6 @@ Layouts and Tensor Subclasses
FloatxTensor
FloatxTensorCoreLayout
MarlinSparseLayout
UintxLayout
Int4CPULayout
CutlassSemiSparseLayout

@@ -53,6 +52,7 @@ Prototype
Int8DynamicActInt4WeightCPULayout
MarlinQQQTensor
MarlinQQQLayout
UintxLayout

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
3 changes: 2 additions & 1 deletion test/dtypes/test_uintx.py
@@ -9,7 +9,7 @@
import pytest
import torch

from torchao.dtypes.uintx.uintx_layout import to_uintx
from torchao.prototype.dtypes.uintx.uintx_layout import to_uintx
from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import (
MappingType,
@@ -183,6 +183,7 @@ def test_uintx_api_deprecation():
("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"),
("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"),
("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"),
("UintxLayout", "torchao.dtypes.uintx.uintx_layout"),
]

for api_name, module_path in deprecated_apis:
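The loop body of `test_uintx_api_deprecation` is truncated in the diff above. As a rough sketch (assumed, not the PR's exact assertions), a check driven by this `deprecated_apis` table could look like:

```python
# Hypothetical helper, not the PR's exact test body (which is truncated
# above): verify a deprecated module still imports, warns, and re-exports.
import importlib
import sys

import pytest


def check_deprecated_import(api_name: str, module_path: str) -> None:
    # Drop any cached module so the module-level warnings.warn fires again.
    sys.modules.pop(module_path, None)
    with pytest.warns(DeprecationWarning, match="deprecated"):
        module = importlib.import_module(module_path)
    # The backward-compatibility stub should still expose the old symbol.
    assert hasattr(module, api_name)
```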
2 changes: 1 addition & 1 deletion torchao/dtypes/__init__.py
@@ -21,7 +21,6 @@
QDQLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
UintxLayout,
)
from .uintx.block_sparse_layout import BlockSparseLayout
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
@@ -31,6 +30,7 @@
MarlinQQQTensor,
to_marlinqqq_quantized_intx,
)
from .uintx.uintx_layout import UintxLayout
from .utils import (
Layout,
PlainLayout,
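With the stub in the next file in place, the old import path keeps working but emits a `DeprecationWarning`. A minimal behavior sketch, assuming a fresh interpreter (the module-level warning only fires when the stub module is first executed):

```python
# Expected behavior sketch, assuming a fresh interpreter so the
# module-level warning in the stub actually fires on import.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from torchao.dtypes import UintxLayout  # old path, still re-exported

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```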
260 changes: 17 additions & 243 deletions torchao/dtypes/uintx/uintx_layout.py
@@ -3,250 +3,24 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
# Backward compatibility stub - imports from the new location
import warnings

from torchao.dtypes.affine_quantized_tensor import register_layout
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.utils import (
Layout,
warnings.warn(
"Importing from torchao.dtypes.uintx.uintx_layout is deprecated. "
"Please use 'from torchao.prototype.dtypes import UintxLayout, UintxTensor' instead. "
"This import path will be removed in a future release of torchao. "
"See https://github.com/pytorch/ao/issues/2752 for more details.",
DeprecationWarning,
stacklevel=2,
)
from torchao.utils import TorchAOBaseTensor

from .bitpacking import pack, unpack

aten = torch.ops.aten

# Note: Uintx does not work for torch 2.3 and below
_DTYPE_TO_BIT_WIDTH = {}
_BIT_WIDTH_TO_DTYPE = {}

_DTYPE_TO_BIT_WIDTH = {
torch.uint1: 1,
torch.uint2: 2,
torch.uint3: 3,
torch.uint4: 4,
torch.uint5: 5,
torch.uint6: 6,
torch.uint7: 7,
}

_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}


class UintxTensor(TorchAOBaseTensor):
"""
Splits int data into packed shards based on bit size
fields:
int4_shard (torch.Tensor): 4 bit packed shard
int2_shard (torch.Tensor): 2 bit packed shard
int1_shard (torch.Tensor): 1 bit packed shard
bit_width (int): number of bits for each element
pack_dim: (int) dimension to pack along
"""

bits_to_shard = {
1: ["int1_shard"],
2: ["int2_shard"],
3: ["int2_shard", "int1_shard"],
4: ["int4_shard"],
5: ["int4_shard", "int1_shard"],
6: ["int4_shard", "int2_shard"],
7: ["int4_shard", "int2_shard", "int1_shard"],
}

def __new__(
cls,
shards: List[torch.Tensor],
packed_shape: List[int],
bit_width: int,
pack_dim: int = -1,
):
kwargs = {"device": shards[0].device}
kwargs["device"] = shards[0].device
kwargs["layout"] = shards[0].layout
kwargs["requires_grad"] = False
kwargs["dtype"] = torch.uint8
return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs)

def __init__(
self,
shards: List[torch.Tensor],
packed_shape: List[int],
bit_width: int,
pack_dim: int = -1,
):
for i, attrib in enumerate(self.bits_to_shard[bit_width]):
setattr(self, attrib, shards[i])

self.packed_shape = packed_shape
self.bit_width = bit_width
self.pack_dim = pack_dim

def get_shards(self):
return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]]

def __repr__(self):
return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)})"

def __tensor_flatten__(self):
return self.__class__.bits_to_shard[self.bit_width], [
self.packed_shape,
self.bit_width,
self.pack_dim,
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
shards = list(tensor_data_dict.values())
packed_shape, bit_width, pack_dim = tensor_attributes
return cls(shards, packed_shape, bit_width, pack_dim)

def get_plain(self):
return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)

# temporary until kernels on packed tensors are created
def apply_transformation(self, fn):
og = self.get_plain()
new = fn(og)
dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
return self.from_uint8(new, dtype, self.pack_dim)

# temporary until kernels on packed tensors are created
def apply_fn_to_shards(self, fn):
new_shards = [fn(shard) for shard in self.get_shards()]
return self.__class__(
new_shards, self.packed_shape, self.bit_width, self.pack_dim
)

@classmethod
def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), (
"Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}"
)
bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
shards = pack(int_data, bit_width, dim=pack_dim)
shape = list(int_data.shape)
shape[pack_dim] = shape[pack_dim] * bit_width // 8
return cls(shards, int_data.shape, bit_width, pack_dim)

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
memory_format = (
memory_format if memory_format is not None else torch.preserve_format
)
kwargs = {
"device": device,
"dtype": dtype,
"memory_format": memory_format,
}
return kwargs

def to(self, *args, **kwargs):
if "copy" in kwargs:
return super().to(*args, **kwargs)
kwargs = self._get_to_kwargs(*args, **kwargs)
if "device" in kwargs:
return self.__class__(
list(shard.to(kwargs["device"]) for shard in self.get_shards()),
self.packed_shape,
self.bit_width,
self.pack_dim,
)
return super().to(*args, **kwargs)


implements = UintxTensor.implements


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0].apply_fn_to_shards(torch.detach)
)


@implements(aten.view.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:]))
)


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0])


@implements(aten.sub.Tensor)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)),
)


@implements(aten.mul.Tensor)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)),
)


# quantization api integrations
to_uintx = UintxTensor.from_uint8


@dataclass(frozen=True)
class UintxLayout(Layout):
"""A layout class for Uintx tensors, which are tensors with elements packed into
smaller bit-widths than the standard 8-bit byte. This layout is used to define
how the data is stored and processed in UintxTensor objects.

Attributes:
dtype (torch.dtype): The data type of the tensor elements, which determines
the bit-width used for packing.
pack_dim (int): The dimension along which the data is packed. Default is -1,
which indicates the last dimension.
"""

dtype: torch.dtype
pack_dim: int = -1

def post_process(
self,
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
block_size: Tuple[int, ...],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point


@register_layout(UintxLayout)
class UintxAQTTensorImpl(PlainAQTTensorImpl):
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.int_data.get_plain(), self.scale, self.zero_point

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
_layout: Layout,
):
assert isinstance(_layout, UintxLayout)
return cls(int_data, scale, zero_point, _layout)
from torchao.prototype.dtypes.uintx.uintx_layout import ( # noqa: F401
_BIT_WIDTH_TO_DTYPE, # noqa: F401
_DTYPE_TO_BIT_WIDTH, # noqa: F401
UintxAQTTensorImpl, # noqa: F401
UintxLayout, # noqa: F401
UintxTensor, # noqa: F401
to_uintx, # noqa: F401
)
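The removed implementation now lives at `torchao.prototype.dtypes.uintx.uintx_layout`; `to_uintx` is an alias for `UintxTensor.from_uint8`, which bit-packs sub-byte values into `uint8` shards. A minimal round-trip sketch, assuming the relocated code behaves exactly as shown above:

```python
import torch

from torchao.prototype.dtypes import to_uintx

# Values must fit the target bit width (0..15 for torch.uint4).
x = torch.randint(0, 16, (4, 8), dtype=torch.uint8)

# Packs x into a single 4-bit shard along the last dimension.
packed = to_uintx(x, torch.uint4, pack_dim=-1)

# get_plain() unpacks the shards back to the original values.
assert torch.equal(packed.get_plain(), x)
```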
2 changes: 1 addition & 1 deletion torchao/prototype/autoround/core.py
@@ -189,7 +189,7 @@ def to_uintx_weight(input_float):
quant_min = 0
quant_max = _auto_round_config.bits**2 - 1
block_size = (1, observed_linear.group_size)
from torchao.dtypes.uintx.uintx import (
from torchao.prototype.dtypes.uintx.uintx_layout import (
_BIT_WIDTH_TO_DTYPE,
UintxLayout,
)
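`UintxLayout` itself is a frozen dataclass (per the removed source above) carrying only the target dtype and pack dimension. A sketch of how these two imports combine, assumed from that definition rather than the truncated diff body:

```python
# Combining _BIT_WIDTH_TO_DTYPE with the UintxLayout dataclass;
# the 6-bit case is illustrative.
import torch

from torchao.prototype.dtypes import UintxLayout
from torchao.prototype.dtypes.uintx.uintx_layout import _BIT_WIDTH_TO_DTYPE

bits = 6
layout = UintxLayout(dtype=_BIT_WIDTH_TO_DTYPE[bits], pack_dim=-1)
assert layout.dtype is torch.uint6
```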
4 changes: 3 additions & 1 deletion torchao/prototype/autoround/eval_autoround.py
@@ -111,7 +111,9 @@ def main(args):
)
elif args.uintx:
msg += f" (uintx {args.bits} bits)"
from torchao.dtypes.uintx.uintx import _BIT_WIDTH_TO_DTYPE
from torchao.prototype.dtypes.uintx.uintx_layout import (
_BIT_WIDTH_TO_DTYPE,
)
from torchao.quantization.quant_api import (
UIntXWeightOnlyConfig,
quantize_,
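These imports feed `UIntXWeightOnlyConfig`. A hedged usage sketch — the `group_size` parameter is assumed from torchao's quant API, not shown in this diff:

```python
import torch

from torchao.prototype.dtypes.uintx.uintx_layout import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 128))

bits = 4
# quantize_ mutates the model in place, swapping Linear weights for
# uintx weight-only quantized tensors.
quantize_(model, UIntXWeightOnlyConfig(dtype=_BIT_WIDTH_TO_DTYPE[bits], group_size=64))
```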
8 changes: 8 additions & 0 deletions torchao/prototype/dtypes/__init__.py
@@ -11,7 +11,11 @@
Int8DynamicActInt4WeightCPULayout,
MarlinQQQLayout,
MarlinQQQTensor,
UintxAQTTensorImpl,
UintxLayout,
UintxTensor,
to_marlinqqq_quantized_intx,
to_uintx,
)

__all__ = [
@@ -22,4 +26,8 @@
"MarlinQQQTensor",
"to_marlinqqq_quantized_intx",
"GemlitePackedLayout",
"UintxLayout",
"UintxTensor",
"UintxAQTTensorImpl",
"to_uintx",
]
10 changes: 10 additions & 0 deletions torchao/prototype/dtypes/uintx/__init__.py
@@ -13,6 +13,12 @@
MarlinQQQTensor,
to_marlinqqq_quantized_intx,
)
from .uintx_layout import (
UintxAQTTensorImpl,
UintxLayout,
UintxTensor,
to_uintx,
)

__all__ = [
"BlockSparseLayout",
@@ -22,4 +28,8 @@
"MarlinQQQTensor",
"to_marlinqqq_quantized_intx",
"GemlitePackedLayout",
"UintxLayout",
"UintxTensor",
"UintxAQTTensorImpl",
"to_uintx",
]
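After this PR, the canonical imports come from the prototype namespace, mirroring the `__all__` entries above and the path named in the deprecation message:

```python
from torchao.prototype.dtypes import (
    UintxAQTTensorImpl,
    UintxLayout,
    UintxTensor,
    to_uintx,
)
```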