@@ -80,6 +80,7 @@ Tensor pack_weights_cpu(
       weight_scales.dtype() == torch::kFloat32,
       "weight_scales must be float32");
   CHECK_MSG(weight_scales.dim() == 1, "weight_scales must be 1D");
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
   CHECK_MSG(
       weight_scales.size(0) == ((n * k) / group_size),
       "expected 1 scale per group");
@@ -134,9 +135,9 @@ Tensor pack_weights_without_zeros_cpu(
     const Tensor& weight_qvals,
     const Tensor& weight_scales,
     // TODO(T200095131): convert to int64_t when supported by AOTI
-    // group_size is a meta tensor with size (group_size)
+    // group_size is a tensor with size (0, group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_cpu<weight_nbit, /*has_weight_zeros*/ false>(
       weight_qvals, weight_scales, std::nullopt, group_size);
 }
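The `size(0)` to `size(1)` change tracks the convention spelled out in the updated comment: the integer rides as the second dimension of an otherwise empty `(0, value)` tensor, so no elements are materialized. A hedged sketch of how a caller could build such a wrapper with the ATen C++ API follows; the construction is an assumption for illustration, not code from this commit:

```cpp
// Sketch only: wrapping an int as a (0, value) tensor so .size(1) recovers it
// while the tensor itself holds no elements. Assumes libtorch/ATen is available.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t group_size = 32;  // hypothetical value
  // First dim is 0, so the tensor is empty; the value rides in dim 1.
  torch::Tensor group_size_tensor = torch::empty({0, group_size});
  std::cout << group_size_tensor.size(1) << "\n";  // 32, what the op reads back
  std::cout << group_size_tensor.numel() << "\n";  // 0, no elements
  return 0;
}
```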
@@ -151,7 +152,7 @@ Tensor pack_weights_with_zeros_cpu(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_cpu<weight_nbit, /*has_weight_zeros*/ true>(
       weight_qvals, weight_scales, weight_zeros, group_size);
 }
@@ -164,6 +165,7 @@ Tensor pack_weights_meta(
     const Tensor& weight_scales,
     const std::optional<Tensor>& weight_zeros,
     int64_t group_size) {
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
   int n = weight_qvals.size(0);
   int k = weight_qvals.size(1);
 
@@ -190,7 +192,7 @@ Tensor pack_weights_without_zeros_meta(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_meta<weight_nbit, /*has_weight_zeros*/ false>(
       weight_qvals, weight_scales, std::nullopt, group_size);
 }
@@ -205,7 +207,7 @@ Tensor pack_weights_with_zeros_meta(
     // TODO(T200095131): convert to int64_t when supported by AOTI
     // group_size is a meta tensor with size (group_size)
     const Tensor& group_size_tensor) {
-  int64_t group_size = group_size_tensor.size(0);
+  int64_t group_size = group_size_tensor.size(1);
   return pack_weights_meta<weight_nbit, /*has_weight_zeros*/ true>(
       weight_qvals, weight_scales, weight_zeros, group_size);
 }
@@ -216,16 +218,19 @@ template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_out_cpu(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
-    // int64_t when supported by AOTI Currently they are meta tensors with size
-    // equal to the int they wrap
+    // int64_t when supported by AOTI Currently they are tensors with size
+    // equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
     const Tensor& activations,
     Tensor& out) {
-  int n = n_tensor.size(0);
-  int k = k_tensor.size(0);
-  int group_size = group_size_tensor.size(0);
+  int n = n_tensor.size(1);
+  int k = k_tensor.size(1);
+  int group_size = group_size_tensor.size(1);
+  CHECK_MSG(n >= 1, "n must be >= 1");
+  CHECK_MSG(k >= 1, "k must be >= 1");
+  CHECK_MSG(group_size >= 1, "group_size must be >= 1");
 
 #ifdef USE_ATEN
   CHECK_MSG(
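One observable consequence of reading dim 1 instead of dim 0: a wrapper still built with the old 1-D convention has no dim 1, so the read fails loudly rather than returning a wrong size. A small sketch of that contrast (hypothetical values, not code from this commit):

```cpp
// Sketch only: old 1-D wrapper vs. the (0, value) wrapper that linear_out_cpu
// now expects. Hypothetical values.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor old_style = torch::empty({64});     // old convention: size(0) == 64
  torch::Tensor new_style = torch::empty({0, 64});  // new convention: size(1) == 64

  std::cout << new_style.size(1) << "\n";  // 64, as the dim-1 read expects
  try {
    old_style.size(1);                     // a 1-D tensor has no dim 1
  } catch (const c10::Error&) {
    std::cout << "old-style wrapper rejected: dimension out of range\n";
  }
  return 0;
}
```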
@@ -303,8 +308,8 @@ template <int weight_nbit, bool has_weight_zeros>
 Tensor linear_cpu(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
-    // int64_t when supported by AOTI Currently they are meta tensors with size
-    // equal to the int they wrap
+    // int64_t when supported by AOTI Currently they are tensors with size
+    // equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
@@ -327,14 +332,17 @@ Tensor linear_meta(
     const Tensor& packed_weights,
     // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
     // int64_t when supported by AOTI
-    // Currently they are meta tensors with size equal to the int they wrap
+    // Currently they are tensors with size equal to (0, the int they wrap)
     const Tensor& n_tensor,
     const Tensor& k_tensor,
     const Tensor& group_size_tensor,
     const Tensor& activations) {
-  int n = n_tensor.size(0);
-  int k = k_tensor.size(0);
+  int n = n_tensor.size(1);
+  int k = k_tensor.size(1);
+  CHECK_MSG(n >= 1, "n must be >= 1");
+  CHECK_MSG(k >= 1, "k must be >= 1");
 
+  CHECK_MSG(activations.dim() == 2, "activations must be 2D");
   int m = activations.size(0);
   int k_ = activations.size(1);
   CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
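`linear_meta` only does shape bookkeeping, and this hunk adds a rank check before `m` and `k_` are read from `activations`. A standalone sketch of that bookkeeping with hypothetical sizes (a plain CPU tensor stands in for the meta tensor the real op would see, and `assert` for `CHECK_MSG`):

```cpp
// Sketch only: the shape checks linear_meta performs before an output shape
// can be formed. Hypothetical sizes.
#include <torch/torch.h>
#include <cassert>

int main() {
  const int64_t n = 256, k = 512, m = 4;   // hypothetical op sizes
  torch::Tensor activations = torch::empty({m, k});

  assert(activations.dim() == 2);          // mirrors the new rank check
  const int64_t m_read = activations.size(0);
  const int64_t k_read = activations.size(1);
  assert(k_read == k);                     // must match the packed weights' k
  // For a linear op with n output channels, the natural output shape is (m, n).
  torch::Tensor out = torch::empty({m_read, n});
  assert(out.size(0) == m && out.size(1) == n);
  return 0;
}
```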