
Commit 63cb7a9

Change n, k, group_size tensors to have no elements
Differential Revision: D63467171
Pull Request resolved: #956
1 parent b149edb · commit 63cb7a9
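
The change in a nutshell: AOTI cannot yet pass plain int64_t arguments to these ops, so n, k, and group_size travel as tensors whose shape encodes the value. Before this commit each was a 1-D tensor of length equal to the value (so it actually allocated that many bytes); after it, each has shape (0, value), holding zero elements while the value stays readable from size(1). A minimal sketch of the scheme, with helper names that are illustrative and not part of the torchao API:

import torch

# Illustrative helpers (not part of torchao): wrap an int that AOTI cannot
# pass directly as a tensor of shape (0, value). The tensor allocates no
# elements, but the wrapped value is still recoverable from size(1).
def wrap_int(value: int) -> torch.Tensor:
    return torch.empty(0, value, dtype=torch.int8)

def unwrap_int(t: torch.Tensor) -> int:
    # mirrors the C++ side, which now reads tensor.size(1) instead of size(0)
    return t.size(1)

group_size = wrap_int(32)
assert group_size.numel() == 0       # before: torch.empty(32) held 32 bytes
assert unwrap_int(group_size) == 32  # the wrapped value is recoverable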

File tree: 3 files changed, 37 additions and 26 deletions

torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h

Lines changed: 23 additions & 15 deletions
@@ -80,6 +80,7 @@ Tensor pack_weights_cpu(
       weight_scales.dtype() == torch::kFloat32,
       "weight_scales must be float32");
   CHECK_MSG(weight_scales.dim() == 1, "weight_scales must be 1D");
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
   CHECK_MSG(
       weight_scales.size(0) == ((n * k) / group_size),
       "expected 1 scale per group");
@@ -134,9 +135,9 @@ Tensor pack_weights_without_zeros_cpu(
     const Tensor& weight_qvals,
     const Tensor& weight_scales,
     // TODO(T200095131): convert to int64_t when supported by AOTI
-    // group_size is a meta tensor with size (group_size)
+    // group_size is a tensor with size (0, group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_cpu<weight_nbit, /*has_weight_zeros*/ false>(
       weight_qvals, weight_scales, std::nullopt, group_size);
 }
@@ -151,7 +152,7 @@ Tensor pack_weights_with_zeros_cpu(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_cpu<weight_nbit, /*has_weight_zeros*/ true>(
       weight_qvals, weight_scales, weight_zeros, group_size);
 }
@@ -164,6 +165,7 @@ Tensor pack_weights_meta(
     const Tensor& weight_scales,
     const std::optional<Tensor>& weight_zeros,
     int64_t group_size) {
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
   int n = weight_qvals.size(0);
   int k = weight_qvals.size(1);

@@ -190,7 +192,7 @@ Tensor pack_weights_without_zeros_meta(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_meta<weight_nbit, /*has_weight_zeros*/ false>(
       weight_qvals, weight_scales, std::nullopt, group_size);
 }
@@ -205,7 +207,7 @@ Tensor pack_weights_with_zeros_meta(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_meta<weight_nbit, /*has_weight_zeros*/ true>(
       weight_qvals, weight_scales, weight_zeros, group_size);
 }
@@ -216,16 +218,19 @@ template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_out_cpu(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
-    // int64_t when supported by AOTI Currently they are meta tensors with size
-    // equal to the int they wrap
+    // int64_t when supported by AOTI Currently they are tensors with size
+    // equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
     const Tensor& activations,
     Tensor& out) {
-  int n = n_tensor.size(0);
-  int k = k_tensor.size(0);
-  int group_size = group_size_tensor.size(0);
+  int n = n_tensor.size(1);
+  int k = k_tensor.size(1);
+  int group_size = group_size_tensor.size(1);
+  CHECK_MSG(n >= 1, "n must be >= 1");
+  CHECK_MSG(k >= 1, "k must be >= 1");
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");

 #ifdef USE_ATEN
   CHECK_MSG(
@@ -303,8 +308,8 @@ template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_cpu(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
-    // int64_t when supported by AOTI Currently they are meta tensors with size
-    // equal to the int they wrap
+    // int64_t when supported by AOTI Currently they are tensors with size
+    // equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
@@ -327,14 +332,17 @@ Tensor linear_meta(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
     // int64_t when supported by AOTI
-    // Currently they are meta tensors with size equal to the int they wrap
+    // Currently they are tensors with size equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
     const Tensor& activations) {
-  int n = n_tensor.size(0);
-  int k = k_tensor.size(0);
+  int n = n_tensor.size(1);
+  int k = k_tensor.size(1);
+  CHECK_MSG(n >= 1, "n must be >= 1");
+  CHECK_MSG(k >= 1, "k must be >= 1");

+  CHECK_MSG(activations.dim() == 2, "activations must be 2D");
   int m = activations.size(0);
   int k_ = activations.size(1);
   CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
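
For illustration, a Python mirror of the argument handling the hunks above introduce on the C++ side. This is a sketch under the assumption that it matches the CHECK_MSG logic; the real checks live in linear_a8wxdq-impl.h.

import torch

# Sketch only (mirrors the C++ checks above; not real kernel code).
def read_linear_args(n_tensor, k_tensor, group_size_tensor, activations):
    # the wrapped ints now live in dim 1 of the zero-element tensors
    n = n_tensor.size(1)
    k = k_tensor.size(1)
    group_size = group_size_tensor.size(1)
    assert n >= 1, "n must be >= 1"
    assert k >= 1, "k must be >= 1"
    assert group_size >= 1, "group_size must be >= 1"
    assert activations.dim() == 2, "activations must be 2D"
    assert activations.size(1) == k, (
        "activation shape is incompatible with packed weights."
    )
    return n, k, group_size

n, k, g = read_linear_args(
    torch.empty(0, 256, dtype=torch.int8),
    torch.empty(0, 512, dtype=torch.int8),
    torch.empty(0, 32, dtype=torch.int8),
    torch.randn(3, 512),
)
assert (n, k, g) == (256, 512, 32)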

torchao/experimental/quant_api.py

Lines changed: 7 additions & 5 deletions
@@ -80,9 +80,11 @@ def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros)

         # TODO(T200095131): convert self.n, self.k, self.group_size to
         # int when supported by AOTI
-        self._n = torch.empty(n, dtype=torch.int8)
-        self._k = torch.empty(k, dtype=torch.int8)
-        self._group_size = torch.empty(self.group_size, dtype=torch.int8)
+        # AOTI does not allow a tensor of size (n, 0), so we do (0, n)
+        self._n = torch.empty(0, n, dtype=torch.int8)
+        self._k = torch.empty(0, k, dtype=torch.int8)
+        self._group_size = torch.empty(0, group_size, dtype=torch.int8)
+

         weight_qvals, weight_scales, weight_zeros = _quantize(
             weights, self.group_size, self.nbit, self.has_weight_zeros
@@ -109,7 +111,7 @@ def forward(self, x):
         assert x.dim() >= 3
         lead_shape = x.shape[0:-2]
         m, k = x.shape[-2], x.shape[-1]
-        n = self._n.shape[0]
+        n = self._n.shape[1]
         x = x.reshape(-1, m, k)

         res = [
@@ -254,7 +256,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
             if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative):
                 raise e
             logger.warning(
-                "_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
+                f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
                 + "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
             )
             qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
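
For context, the forward path touched by hunk @@ -109,7 +111,7 @@ flattens arbitrary leading batch dimensions before calling the quantized op and restores them afterwards. A rough eager-mode analogue (an assumption for illustration: a plain float matmul stands in for the quantized op, and n comes from the weight rather than from self._n.shape[1]):

import torch

# Rough eager-mode analogue of the forward() shown above; not the real module.
def forward_like(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    assert x.dim() >= 3
    lead_shape = x.shape[0:-2]          # leading batch dims are preserved
    m, k = x.shape[-2], x.shape[-1]
    n = weight.shape[0]
    x = x.reshape(-1, m, k)             # flatten to a stack of (m, k) matrices
    res = torch.stack([xi @ weight.t() for xi in x])
    return res.reshape(*lead_shape, m, n)

out = forward_like(torch.randn(2, 1, 3, 512), torch.randn(256, 512))
assert out.shape == (2, 1, 3, 256)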

torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py

Lines changed: 7 additions & 6 deletions
@@ -77,16 +77,17 @@ def test_accuracy(self):

     def test_export_compile_aoti(self):
         group_size = 32
-        m = 1
-        n = 256
-        k = 256
+        m = 3
+        k0 = 512
+        k1 = 256
+        k2 = 128
+        k3 = 1024
         nbit = 4
         has_weight_zeros = False
-        n_layers = 3
-        layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
+        layers = [torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), torch.nn.Linear(k2, k3, bias=False)]
         model = torch.nn.Sequential(*layers)

-        activations = torch.randn(m, k, dtype=torch.float32)
+        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)

         print("Quantizing model")
         quantizer = Int8DynActIntxWeightQuantizer(
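
The reworked test now exercises layers of different widths and activations with extra leading batch dimensions. A quick shape check of that configuration in plain eager mode (no quantization, just to show how the dimensions chain):

import torch

# Shape check only: the three layers chain 512 -> 256 -> 128 -> 1024, and the
# 4D activations keep their leading (2, 1) batch dims through nn.Linear.
m, k0, k1, k2, k3 = 3, 512, 256, 128, 1024
layers = [
    torch.nn.Linear(k0, k1, bias=False),
    torch.nn.Linear(k1, k2, bias=False),
    torch.nn.Linear(k2, k3, bias=False),
]
model = torch.nn.Sequential(*layers)
out = model(torch.randn(2, 1, m, k0))
assert out.shape == (2, 1, m, k3)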
