Skip to content

Commit 1a9b6f4

Browse files
authored
Move uintx_layout to prototype/dtypes (#3316)
1 parent 1e5bc3b commit 1a9b6f4

File tree

10 files changed

+295
-249
lines changed

10 files changed

+295
-249
lines changed

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ Layouts and Tensor Subclasses
2222
FloatxTensor
2323
FloatxTensorCoreLayout
2424
MarlinSparseLayout
25-
UintxLayout
2625
Int4CPULayout
2726
CutlassSemiSparseLayout
2827

@@ -53,6 +52,7 @@ Prototype
5352
Int8DynamicActInt4WeightCPULayout
5453
MarlinQQQTensor
5554
MarlinQQQLayout
55+
UintxLayout
5656

5757
..
5858
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring

test/dtypes/test_uintx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
import torch
1111

12-
from torchao.dtypes.uintx.uintx_layout import to_uintx
12+
from torchao.prototype.dtypes.uintx.uintx_layout import to_uintx
1313
from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
1414
from torchao.quantization.quant_primitives import (
1515
MappingType,
@@ -183,6 +183,7 @@ def test_uintx_api_deprecation():
183183
("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"),
184184
("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"),
185185
("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"),
186+
("UintxLayout", "torchao.dtypes.uintx.uintx_layout"),
186187
]
187188

188189
for api_name, module_path in deprecated_apis:

torchao/dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
QDQLayout,
2222
SemiSparseLayout,
2323
TensorCoreTiledLayout,
24-
UintxLayout,
2524
)
2625
from .uintx.block_sparse_layout import BlockSparseLayout
2726
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
@@ -31,6 +30,7 @@
3130
MarlinQQQTensor,
3231
to_marlinqqq_quantized_intx,
3332
)
33+
from .uintx.uintx_layout import UintxLayout
3434
from .utils import (
3535
Layout,
3636
PlainLayout,

torchao/dtypes/uintx/uintx_layout.py

Lines changed: 17 additions & 243 deletions
Original file line numberDiff line numberDiff line change
@@ -3,250 +3,24 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6-
from dataclasses import dataclass
7-
from typing import List, Tuple
86

9-
import torch
10-
from torch.utils._python_dispatch import return_and_correct_aliasing
7+
# Backward compatibility stub - imports from the new location
8+
import warnings
119

12-
from torchao.dtypes.affine_quantized_tensor import register_layout
13-
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
14-
from torchao.dtypes.utils import (
15-
Layout,
10+
warnings.warn(
11+
"Importing from torchao.dtypes.uintx.uintx_layout is deprecated. "
12+
"Please use 'from torchao.prototype.dtypes import UintxLayout, UintxTensor' instead. "
13+
"This import path will be removed in a future release of torchao. "
14+
"See https://github.com/pytorch/ao/issues/2752 for more details.",
15+
DeprecationWarning,
16+
stacklevel=2,
1617
)
17-
from torchao.utils import TorchAOBaseTensor
1818

19-
from .bitpacking import pack, unpack
20-
21-
aten = torch.ops.aten
22-
23-
# Note: Uintx does not work for torch 2.3 and below
24-
_DTYPE_TO_BIT_WIDTH = {}
25-
_BIT_WIDTH_TO_DTYPE = {}
26-
27-
_DTYPE_TO_BIT_WIDTH = {
28-
torch.uint1: 1,
29-
torch.uint2: 2,
30-
torch.uint3: 3,
31-
torch.uint4: 4,
32-
torch.uint5: 5,
33-
torch.uint6: 6,
34-
torch.uint7: 7,
35-
}
36-
37-
_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
38-
39-
40-
class UintxTensor(TorchAOBaseTensor):
41-
"""
42-
Splits int data into packed shards based on bit size
43-
fields:
44-
int4_shard (torch.Tensor): 4 bit packed shard
45-
int2_shard (torch.Tensor): 2 bit packed shard
46-
int1_shard (torch.Tensor): 1 bit packed shard
47-
bit_width (int): number of bits for each element
48-
pack_dim: (int) dimension to pack along
49-
"""
50-
51-
bits_to_shard = {
52-
1: ["int1_shard"],
53-
2: ["int2_shard"],
54-
3: ["int2_shard", "int1_shard"],
55-
4: ["int4_shard"],
56-
5: ["int4_shard", "int1_shard"],
57-
6: ["int4_shard", "int2_shard"],
58-
7: ["int4_shard", "int2_shard", "int1_shard"],
59-
}
60-
61-
def __new__(
62-
cls,
63-
shards: List[torch.Tensor],
64-
packed_shape: List[int],
65-
bit_width: int,
66-
pack_dim: int = -1,
67-
):
68-
kwargs = {"device": shards[0].device}
69-
kwargs["device"] = shards[0].device
70-
kwargs["layout"] = shards[0].layout
71-
kwargs["requires_grad"] = False
72-
kwargs["dtype"] = torch.uint8
73-
return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs)
74-
75-
def __init__(
76-
self,
77-
shards: List[torch.Tensor],
78-
packed_shape: List[int],
79-
bit_width: int,
80-
pack_dim: int = -1,
81-
):
82-
for i, attrib in enumerate(self.bits_to_shard[bit_width]):
83-
setattr(self, attrib, shards[i])
84-
85-
self.packed_shape = packed_shape
86-
self.bit_width = bit_width
87-
self.pack_dim = pack_dim
88-
89-
def get_shards(self):
90-
return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]]
91-
92-
def __repr__(self):
93-
return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)})"
94-
95-
def __tensor_flatten__(self):
96-
return self.__class__.bits_to_shard[self.bit_width], [
97-
self.packed_shape,
98-
self.bit_width,
99-
self.pack_dim,
100-
]
101-
102-
@classmethod
103-
def __tensor_unflatten__(
104-
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
105-
):
106-
shards = list(tensor_data_dict.values())
107-
packed_shape, bit_width, pack_dim = tensor_attributes
108-
return cls(shards, packed_shape, bit_width, pack_dim)
109-
110-
def get_plain(self):
111-
return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)
112-
113-
# temporary until kernels on packed tensors are created
114-
def apply_transformation(self, fn):
115-
og = self.get_plain()
116-
new = fn(og)
117-
dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
118-
return self.from_uint8(new, dtype, self.pack_dim)
119-
120-
# temporary until kernels on packed tensors are created
121-
def apply_fn_to_shards(self, fn):
122-
new_shards = [fn(shard) for shard in self.get_shards()]
123-
return self.__class__(
124-
new_shards, self.packed_shape, self.bit_width, self.pack_dim
125-
)
126-
127-
@classmethod
128-
def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
129-
assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), (
130-
"Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}"
131-
)
132-
bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
133-
shards = pack(int_data, bit_width, dim=pack_dim)
134-
shape = list(int_data.shape)
135-
shape[pack_dim] = shape[pack_dim] * bit_width // 8
136-
return cls(shards, int_data.shape, bit_width, pack_dim)
137-
138-
def _get_to_kwargs(self, *args, **kwargs):
139-
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
140-
device = self.device if device is None else device
141-
dtype = self.dtype if dtype is None else dtype
142-
memory_format = (
143-
memory_format if memory_format is not None else torch.preserve_format
144-
)
145-
kwargs = {
146-
"device": device,
147-
"dtype": dtype,
148-
"memory_format": memory_format,
149-
}
150-
return kwargs
151-
152-
def to(self, *args, **kwargs):
153-
if "copy" in kwargs:
154-
return super().to(*args, **kwargs)
155-
kwargs = self._get_to_kwargs(*args, **kwargs)
156-
if "device" in kwargs:
157-
return self.__class__(
158-
list(shard.to(kwargs["device"]) for shard in self.get_shards()),
159-
self.packed_shape,
160-
self.bit_width,
161-
self.pack_dim,
162-
)
163-
return super().to(*args, **kwargs)
164-
165-
166-
implements = UintxTensor.implements
167-
168-
169-
@implements(aten.detach.default)
170-
def _(func, types, args, kwargs):
171-
return return_and_correct_aliasing(
172-
func, args, kwargs, args[0].apply_fn_to_shards(torch.detach)
173-
)
174-
175-
176-
@implements(aten.view.default)
177-
def _(func, types, args, kwargs):
178-
return return_and_correct_aliasing(
179-
func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:]))
180-
)
181-
182-
183-
@implements(aten._to_copy.default)
184-
def _(func, types, args, kwargs):
185-
return return_and_correct_aliasing(func, args, kwargs, args[0])
186-
187-
188-
@implements(aten.sub.Tensor)
189-
def _(func, types, args, kwargs):
190-
return return_and_correct_aliasing(
191-
func,
192-
args,
193-
kwargs,
194-
args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)),
195-
)
196-
197-
198-
@implements(aten.mul.Tensor)
199-
def _(func, types, args, kwargs):
200-
return return_and_correct_aliasing(
201-
func,
202-
args,
203-
kwargs,
204-
args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)),
205-
)
206-
207-
208-
# quantization api integrations
209-
to_uintx = UintxTensor.from_uint8
210-
211-
212-
@dataclass(frozen=True)
213-
class UintxLayout(Layout):
214-
"""A layout class for Uintx tensors, which are tensors with elements packed into
215-
smaller bit-widths than the standard 8-bit byte. This layout is used to define
216-
how the data is stored and processed in UintxTensor objects.
217-
218-
Attributes:
219-
dtype (torch.dtype): The data type of the tensor elements, which determines
220-
the bit-width used for packing.
221-
pack_dim (int): The dimension along which the data is packed. Default is -1,
222-
which indicates the last dimension.
223-
"""
224-
225-
dtype: torch.dtype
226-
pack_dim: int = -1
227-
228-
def post_process(
229-
self,
230-
input: torch.Tensor,
231-
scale: torch.Tensor,
232-
zero_point: torch.Tensor,
233-
block_size: Tuple[int, ...],
234-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
235-
return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point
236-
237-
238-
@register_layout(UintxLayout)
239-
class UintxAQTTensorImpl(PlainAQTTensorImpl):
240-
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
241-
return self.int_data.get_plain(), self.scale, self.zero_point
242-
243-
@classmethod
244-
def from_plain(
245-
cls,
246-
int_data: torch.Tensor,
247-
scale: torch.Tensor,
248-
zero_point: torch.Tensor,
249-
_layout: Layout,
250-
):
251-
assert isinstance(_layout, UintxLayout)
252-
return cls(int_data, scale, zero_point, _layout)
19+
from torchao.prototype.dtypes.uintx.uintx_layout import ( # noqa: F401
20+
_BIT_WIDTH_TO_DTYPE, # noqa: F401
21+
_DTYPE_TO_BIT_WIDTH, # noqa: F401
22+
UintxAQTTensorImpl, # noqa: F401
23+
UintxLayout, # noqa: F401
24+
UintxTensor, # noqa: F401
25+
to_uintx, # noqa: F401
26+
)

torchao/prototype/autoround/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def to_uintx_weight(input_float):
189189
quant_min = 0
190190
quant_max = _auto_round_config.bits**2 - 1
191191
block_size = (1, observed_linear.group_size)
192-
from torchao.dtypes.uintx.uintx import (
192+
from torchao.prototype.dtypes.uintx.uintx_layout import (
193193
_BIT_WIDTH_TO_DTYPE,
194194
UintxLayout,
195195
)

torchao/prototype/autoround/eval_autoround.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def main(args):
111111
)
112112
elif args.uintx:
113113
msg += f" (uintx {args.bits} bits)"
114-
from torchao.dtypes.uintx.uintx import _BIT_WIDTH_TO_DTYPE
114+
from torchao.prototype.dtypes.uintx.uintx_layout import (
115+
_BIT_WIDTH_TO_DTYPE,
116+
)
115117
from torchao.quantization.quant_api import (
116118
UIntXWeightOnlyConfig,
117119
quantize_,

torchao/prototype/dtypes/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
Int8DynamicActInt4WeightCPULayout,
1212
MarlinQQQLayout,
1313
MarlinQQQTensor,
14+
UintxAQTTensorImpl,
15+
UintxLayout,
16+
UintxTensor,
1417
to_marlinqqq_quantized_intx,
18+
to_uintx,
1519
)
1620

1721
__all__ = [
@@ -22,4 +26,8 @@
2226
"MarlinQQQTensor",
2327
"to_marlinqqq_quantized_intx",
2428
"GemlitePackedLayout",
29+
"UintxLayout",
30+
"UintxTensor",
31+
"UintxAQTTensorImpl",
32+
"to_uintx",
2533
]

torchao/prototype/dtypes/uintx/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
MarlinQQQTensor,
1414
to_marlinqqq_quantized_intx,
1515
)
16+
from .uintx_layout import (
17+
UintxAQTTensorImpl,
18+
UintxLayout,
19+
UintxTensor,
20+
to_uintx,
21+
)
1622

1723
__all__ = [
1824
"BlockSparseLayout",
@@ -22,4 +28,8 @@
2228
"MarlinQQQTensor",
2329
"to_marlinqqq_quantized_intx",
2430
"GemlitePackedLayout",
31+
"UintxLayout",
32+
"UintxTensor",
33+
"UintxAQTTensorImpl",
34+
"to_uintx",
2535
]

0 commit comments

Comments
 (0)