|
3 | 3 | # |
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | | -from dataclasses import dataclass |
7 | | -from typing import List, Tuple |
8 | 6 |
|
9 | | -import torch |
10 | | -from torch.utils._python_dispatch import return_and_correct_aliasing |
| 7 | +# Backward compatibility stub - imports from the new location |
| 8 | +import warnings |
11 | 9 |
|
12 | | -from torchao.dtypes.affine_quantized_tensor import register_layout |
13 | | -from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl |
14 | | -from torchao.dtypes.utils import ( |
15 | | - Layout, |
| 10 | +warnings.warn( |
| 11 | + "Importing from torchao.dtypes.uintx.uintx_layout is deprecated. " |
| 12 | + "Please use 'from torchao.prototype.dtypes import UintxLayout, UintxTensor' instead. " |
| 13 | + "This import path will be removed in a future release of torchao. " |
| 14 | + "See https://github.com/pytorch/ao/issues/2752 for more details.", |
| 15 | + DeprecationWarning, |
| 16 | + stacklevel=2, |
16 | 17 | ) |
17 | | -from torchao.utils import TorchAOBaseTensor |
18 | 18 |
|
19 | | -from .bitpacking import pack, unpack |
20 | | - |
21 | | -aten = torch.ops.aten |
22 | | - |
23 | | -# Note: Uintx does not work for torch 2.3 and below |
24 | | -_DTYPE_TO_BIT_WIDTH = {} |
25 | | -_BIT_WIDTH_TO_DTYPE = {} |
26 | | - |
27 | | -_DTYPE_TO_BIT_WIDTH = { |
28 | | - torch.uint1: 1, |
29 | | - torch.uint2: 2, |
30 | | - torch.uint3: 3, |
31 | | - torch.uint4: 4, |
32 | | - torch.uint5: 5, |
33 | | - torch.uint6: 6, |
34 | | - torch.uint7: 7, |
35 | | -} |
36 | | - |
37 | | -_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} |
38 | | - |
39 | | - |
40 | | -class UintxTensor(TorchAOBaseTensor): |
41 | | - """ |
42 | | - Splits int data into packed shards based on bit size |
43 | | - fields: |
44 | | - int4_shard (torch.Tensor): 4 bit packed shard |
45 | | - int2_shard (torch.Tensor): 2 bit packed shard |
46 | | - int1_shard (torch.Tensor): 1 bit packed shard |
47 | | - bit_width (int): number of bits for each element |
48 | | - pack_dim: (int) dimension to pack along |
49 | | - """ |
50 | | - |
51 | | - bits_to_shard = { |
52 | | - 1: ["int1_shard"], |
53 | | - 2: ["int2_shard"], |
54 | | - 3: ["int2_shard", "int1_shard"], |
55 | | - 4: ["int4_shard"], |
56 | | - 5: ["int4_shard", "int1_shard"], |
57 | | - 6: ["int4_shard", "int2_shard"], |
58 | | - 7: ["int4_shard", "int2_shard", "int1_shard"], |
59 | | - } |
60 | | - |
61 | | - def __new__( |
62 | | - cls, |
63 | | - shards: List[torch.Tensor], |
64 | | - packed_shape: List[int], |
65 | | - bit_width: int, |
66 | | - pack_dim: int = -1, |
67 | | - ): |
68 | | - kwargs = {"device": shards[0].device} |
69 | | - kwargs["device"] = shards[0].device |
70 | | - kwargs["layout"] = shards[0].layout |
71 | | - kwargs["requires_grad"] = False |
72 | | - kwargs["dtype"] = torch.uint8 |
73 | | - return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs) |
74 | | - |
75 | | - def __init__( |
76 | | - self, |
77 | | - shards: List[torch.Tensor], |
78 | | - packed_shape: List[int], |
79 | | - bit_width: int, |
80 | | - pack_dim: int = -1, |
81 | | - ): |
82 | | - for i, attrib in enumerate(self.bits_to_shard[bit_width]): |
83 | | - setattr(self, attrib, shards[i]) |
84 | | - |
85 | | - self.packed_shape = packed_shape |
86 | | - self.bit_width = bit_width |
87 | | - self.pack_dim = pack_dim |
88 | | - |
89 | | - def get_shards(self): |
90 | | - return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] |
91 | | - |
92 | | - def __repr__(self): |
93 | | - return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)})" |
94 | | - |
95 | | - def __tensor_flatten__(self): |
96 | | - return self.__class__.bits_to_shard[self.bit_width], [ |
97 | | - self.packed_shape, |
98 | | - self.bit_width, |
99 | | - self.pack_dim, |
100 | | - ] |
101 | | - |
102 | | - @classmethod |
103 | | - def __tensor_unflatten__( |
104 | | - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride |
105 | | - ): |
106 | | - shards = list(tensor_data_dict.values()) |
107 | | - packed_shape, bit_width, pack_dim = tensor_attributes |
108 | | - return cls(shards, packed_shape, bit_width, pack_dim) |
109 | | - |
110 | | - def get_plain(self): |
111 | | - return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) |
112 | | - |
113 | | - # temporary until kernels on packed tensors are created |
114 | | - def apply_transformation(self, fn): |
115 | | - og = self.get_plain() |
116 | | - new = fn(og) |
117 | | - dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width] |
118 | | - return self.from_uint8(new, dtype, self.pack_dim) |
119 | | - |
120 | | - # temporary until kernels on packed tensors are created |
121 | | - def apply_fn_to_shards(self, fn): |
122 | | - new_shards = [fn(shard) for shard in self.get_shards()] |
123 | | - return self.__class__( |
124 | | - new_shards, self.packed_shape, self.bit_width, self.pack_dim |
125 | | - ) |
126 | | - |
127 | | - @classmethod |
128 | | - def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): |
129 | | - assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), ( |
130 | | - "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" |
131 | | - ) |
132 | | - bit_width = _DTYPE_TO_BIT_WIDTH[dtype] |
133 | | - shards = pack(int_data, bit_width, dim=pack_dim) |
134 | | - shape = list(int_data.shape) |
135 | | - shape[pack_dim] = shape[pack_dim] * bit_width // 8 |
136 | | - return cls(shards, int_data.shape, bit_width, pack_dim) |
137 | | - |
138 | | - def _get_to_kwargs(self, *args, **kwargs): |
139 | | - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) |
140 | | - device = self.device if device is None else device |
141 | | - dtype = self.dtype if dtype is None else dtype |
142 | | - memory_format = ( |
143 | | - memory_format if memory_format is not None else torch.preserve_format |
144 | | - ) |
145 | | - kwargs = { |
146 | | - "device": device, |
147 | | - "dtype": dtype, |
148 | | - "memory_format": memory_format, |
149 | | - } |
150 | | - return kwargs |
151 | | - |
152 | | - def to(self, *args, **kwargs): |
153 | | - if "copy" in kwargs: |
154 | | - return super().to(*args, **kwargs) |
155 | | - kwargs = self._get_to_kwargs(*args, **kwargs) |
156 | | - if "device" in kwargs: |
157 | | - return self.__class__( |
158 | | - list(shard.to(kwargs["device"]) for shard in self.get_shards()), |
159 | | - self.packed_shape, |
160 | | - self.bit_width, |
161 | | - self.pack_dim, |
162 | | - ) |
163 | | - return super().to(*args, **kwargs) |
164 | | - |
165 | | - |
166 | | -implements = UintxTensor.implements |
167 | | - |
168 | | - |
169 | | -@implements(aten.detach.default) |
170 | | -def _(func, types, args, kwargs): |
171 | | - return return_and_correct_aliasing( |
172 | | - func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) |
173 | | - ) |
174 | | - |
175 | | - |
176 | | -@implements(aten.view.default) |
177 | | -def _(func, types, args, kwargs): |
178 | | - return return_and_correct_aliasing( |
179 | | - func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) |
180 | | - ) |
181 | | - |
182 | | - |
183 | | -@implements(aten._to_copy.default) |
184 | | -def _(func, types, args, kwargs): |
185 | | - return return_and_correct_aliasing(func, args, kwargs, args[0]) |
186 | | - |
187 | | - |
188 | | -@implements(aten.sub.Tensor) |
189 | | -def _(func, types, args, kwargs): |
190 | | - return return_and_correct_aliasing( |
191 | | - func, |
192 | | - args, |
193 | | - kwargs, |
194 | | - args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), |
195 | | - ) |
196 | | - |
197 | | - |
198 | | -@implements(aten.mul.Tensor) |
199 | | -def _(func, types, args, kwargs): |
200 | | - return return_and_correct_aliasing( |
201 | | - func, |
202 | | - args, |
203 | | - kwargs, |
204 | | - args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), |
205 | | - ) |
206 | | - |
207 | | - |
208 | | -# quantization api integrations |
209 | | -to_uintx = UintxTensor.from_uint8 |
210 | | - |
211 | | - |
212 | | -@dataclass(frozen=True) |
213 | | -class UintxLayout(Layout): |
214 | | - """A layout class for Uintx tensors, which are tensors with elements packed into |
215 | | - smaller bit-widths than the standard 8-bit byte. This layout is used to define |
216 | | - how the data is stored and processed in UintxTensor objects. |
217 | | -
|
218 | | - Attributes: |
219 | | - dtype (torch.dtype): The data type of the tensor elements, which determines |
220 | | - the bit-width used for packing. |
221 | | - pack_dim (int): The dimension along which the data is packed. Default is -1, |
222 | | - which indicates the last dimension. |
223 | | - """ |
224 | | - |
225 | | - dtype: torch.dtype |
226 | | - pack_dim: int = -1 |
227 | | - |
228 | | - def post_process( |
229 | | - self, |
230 | | - input: torch.Tensor, |
231 | | - scale: torch.Tensor, |
232 | | - zero_point: torch.Tensor, |
233 | | - block_size: Tuple[int, ...], |
234 | | - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
235 | | - return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point |
236 | | - |
237 | | - |
238 | | -@register_layout(UintxLayout) |
239 | | -class UintxAQTTensorImpl(PlainAQTTensorImpl): |
240 | | - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
241 | | - return self.int_data.get_plain(), self.scale, self.zero_point |
242 | | - |
243 | | - @classmethod |
244 | | - def from_plain( |
245 | | - cls, |
246 | | - int_data: torch.Tensor, |
247 | | - scale: torch.Tensor, |
248 | | - zero_point: torch.Tensor, |
249 | | - _layout: Layout, |
250 | | - ): |
251 | | - assert isinstance(_layout, UintxLayout) |
252 | | - return cls(int_data, scale, zero_point, _layout) |
| 19 | +from torchao.prototype.dtypes.uintx.uintx_layout import ( # noqa: F401 |
| 20 | + _BIT_WIDTH_TO_DTYPE, # noqa: F401 |
| 21 | + _DTYPE_TO_BIT_WIDTH, # noqa: F401 |
| 22 | + UintxAQTTensorImpl, # noqa: F401 |
| 23 | + UintxLayout, # noqa: F401 |
| 24 | + UintxTensor, # noqa: F401 |
| 25 | + to_uintx, # noqa: F401 |
| 26 | +) |
0 commit comments