Commit ce8796d (1 parent: 4b18210)

GatherBlockQuantized supports zero points and 8 bits for uint8 dtype (microsoft#25214)
Add support for uint8 GatherBlockQuantized in the following two areas:

* Allow zero points.
* Add a `bits` attribute and support bits=8.

The major change is to update shape inference; unit tests are updated to cover these. Note that only the CPU implementation is included here; the CUDA implementation will be added later in another PR.

### Motivation and Context

Previously, zero points were not supported when the dtype is uint8, and only 4-bit quantization without zero points was supported. This change makes it possible to share 8-bit-quantized lm_head weights between GatherBlockQuantized and MatMulNBits.

For example, when K is a multiple of `block_size`, typical input and output shapes are as follows:

* data has shape (N, K) for 8 bits, or (N, K / 2) for 4 bits.
* scales has shape (N, k_blocks), where k_blocks = K / block_size.
* zero_points has shape (N, k_blocks) for 8 bits, or (N, (k_blocks + 1) / 2) for 4 bits.
* output has shape (..., K), where ... is the shape of `indices`.
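To make these shape rules concrete, here is a minimal standalone sketch (the sizes N=2000, K=3072, block_size=128 are toy values chosen for illustration, not values mandated by this change):

```cpp
// Illustrates GatherBlockQuantized input shapes when K is a multiple of
// block_size. All sizes below are hypothetical examples.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t N = 2000, K = 3072, block_size = 128;
  const int64_t k_blocks = K / block_size;  // scale blocks per row: 24

  for (int64_t bits : {4, 8}) {
    const int64_t components = 8 / bits;  // quantized values packed per byte
    std::cout << "bits=" << bits << ":\n"
              << "  data:        (" << N << ", " << K / components << ")\n"
              << "  scales:      (" << N << ", " << k_blocks << ")\n"
              << "  zero_points: (" << N << ", "
              << (k_blocks + components - 1) / components << ")\n";
  }
  // With indices of shape (batch, seq), the dequantized output is
  // (batch, seq, K).
  return 0;
}
```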

File tree: 6 files changed, +404 −225 lines changed

docs/ContribOperators.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -2053,6 +2053,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 #### Attributes
 
 <dl>
+<dt><tt>bits</tt> : int</dt>
+<dd>Number of bits used for weight quantization. Must be either 4 or 8.</dd>
 <dt><tt>block_size</tt> : int</dt>
 <dd>(Optional) block size used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
 <dt><tt>gather_axis</tt> : int</dt>
```

js/web/test/data/ops/gather-block-quantized.jsonc

Lines changed: 5 additions & 0 deletions
```diff
@@ -21,6 +21,11 @@
       "name": "quantize_axis",
       "data": 2,
       "type": "int"
+    },
+    {
+      "name": "bits",
+      "data": 4,
+      "type": "int"
     }
   ],
   "cases": [
```

onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc

Lines changed: 72 additions & 17 deletions
```diff
@@ -18,17 +18,18 @@ namespace contrib {
 
 namespace {
 template <typename T1>
-int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
+int32_t Get4BitElement(const T1* data_ptr, int64_t data_idx) {
   return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
 }
 
 template <>
-int32_t GetDataElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
+int32_t Get4BitElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
   const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
   // Weights are stored as (nibble2)(nibble1) in uint8_t.
   auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
   return data_val;
 }
+
 }  // namespace
 
 template <typename T1, typename Tind>
```
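For reference, the (nibble2)(nibble1) packing that `Get4BitElement` decodes can be reproduced in isolation; the helper below is a hypothetical standalone re-implementation for illustration, not the kernel code:

```cpp
#include <cassert>
#include <cstdint>

// Unpack the idx-th 4-bit value from a packed uint8 buffer. Matching the
// kernel comment, each byte stores (nibble2 << 4) | nibble1: even indices
// occupy the low nibble, odd indices the high nibble.
int32_t UnpackUint4(const uint8_t* packed, int64_t idx) {
  const uint8_t byte = packed[idx >> 1];   // two values per byte
  return (idx & 1) ? ((byte >> 4) & 0x0F)  // odd index  -> high nibble
                   : (byte & 0x0F);        // even index -> low nibble
}

int main() {
  const uint8_t packed[] = {0x21, 0x43};  // values 1, 2, 3, 4 in order
  assert(UnpackUint4(packed, 0) == 1);
  assert(UnpackUint4(packed, 1) == 2);
  assert(UnpackUint4(packed, 2) == 3);
  assert(UnpackUint4(packed, 3) == 4);
  return 0;
}
```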
```diff
@@ -47,6 +48,13 @@ class GatherBlockQuantized : public OpKernel {
       block_size_ = 128;
     }
 
+    ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
+                "'block_size' must be a power of 2 and not less than 16.");
+
+    constexpr int64_t default_bits = 4;
+    info.GetAttrOrDefault("bits", &bits_, default_bits);
+    ORT_ENFORCE(bits_ == 4 || bits_ == 8, "GatherBlockQuantized only supports bits==4 or 8");
+
     ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
                 "'block_size' must be 2's power and not less than 16.");
   }
@@ -84,6 +92,7 @@ class GatherBlockQuantized : public OpKernel {
   int64_t gather_axis_;
   int64_t quantize_axis_;
   int64_t block_size_;
+  int64_t bits_;
 };
 
 template <typename T1, typename Tind>
```
```diff
@@ -94,13 +103,21 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* context
   p.zero_points_tensor = context->Input<Tensor>(3);
 
   const auto& data_shape = p.data_tensor->Shape();
-  const auto& indices_shape = p.indices_tensor->Shape();
   const auto data_rank = data_shape.NumDimensions();
   p.gather_axis = HandleNegativeAxis(gather_axis_, narrow<int64_t>(data_rank));
+
   p.quantize_axis = HandleNegativeAxis(quantize_axis_, narrow<int64_t>(data_rank));
+  if constexpr (std::is_same_v<T1, uint8_t>) {
+    ORT_RETURN_IF_NOT(p.gather_axis == 0, "For uint8_t data, gather_axis must be 0.");
+    ORT_RETURN_IF_NOT(p.quantize_axis == static_cast<int64_t>(data_rank) - 1, "For uint8_t data, quantize_axis must be the last dimension.");
+    ORT_RETURN_IF_NOT(p.gather_axis != p.quantize_axis, "gather_axis and quantize_axis must not be the same.");
+  }
+
+  const auto& indices_shape = p.indices_tensor->Shape();
+  const auto indices_rank = indices_shape.NumDimensions();
 
   std::vector<int64_t> shape;
-  shape.reserve(data_rank - 1 + indices_shape.NumDimensions());
+  shape.reserve(data_rank - 1 + indices_rank);
 
   // get output tensor
   // replace the dimension for p.gather_axis with the shape from the indices
@@ -113,12 +130,21 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* context
   for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
     shape.push_back(data_shape[narrow<size_t>(i)]);
 
-  // When data is stored as uint8_t, each element has two int4 values.
+  // When bits==4 and data is stored as uint8_t, each element has two int4 values.
   // The shape in the onnx model reflects that by having the last dimension be half the number of values.
-  // Ex: For a true data size of 2000x3072, the onnx model would have data of shape 2000x1536.
+  // Example: For a true data size of 2000x3072, the packed uint8 tensor has shape 2000x1536.
   // However the outputs still need to be of size 2000x3072. Therefore we x2 the last dimension here.
-  uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
-  shape[shape.size() - 1] = shape.back() * components;
+  uint32_t components = 1;
+  if constexpr (std::is_same_v<T1, uint8_t>) {
+    components = 8 / static_cast<int>(bits_);
+    if (components > 1) {
+      // To handle quantize_axis that is not the last dimension:
+      // shape[(p.quantize_axis < p.gather_axis) ? p.quantize_axis : p.quantize_axis + indices_rank - 1] *= components;
+      // Since we constrain the last dimension to be the quantize_axis, we can simplify it to:
+      shape.back() *= components;
+    }
+  }
+
   p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
 
   // validate quantization parameters
```
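The output-shape rule this hunk implements (gather_axis == 0 and quantize_axis last, as the CPU kernel now enforces for uint8) can be sketched in isolation; `OutputShape` below is a hypothetical helper, not kernel code:

```cpp
#include <cstdint>
#include <vector>

// Output shape = indices shape (replacing data dim 0), then the remaining
// data dims, with the last (quantize) dim expanded by `components`.
std::vector<int64_t> OutputShape(const std::vector<int64_t>& data_shape,
                                 const std::vector<int64_t>& indices_shape,
                                 int64_t bits) {
  const int64_t components = 8 / bits;        // 2 for bits=4, 1 for bits=8
  std::vector<int64_t> shape(indices_shape);  // gather_axis == 0
  shape.insert(shape.end(), data_shape.begin() + 1, data_shape.end());
  shape.back() *= components;  // unpack the packed quantize dimension
  return shape;
}

int main() {
  // data (2000, 1536) int4-packed, indices (4, 8) -> output (4, 8, 3072)
  auto s = OutputShape({2000, 1536}, {4, 8}, /*bits=*/4);
  return s == std::vector<int64_t>({4, 8, 3072}) ? 0 : 1;
}
```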
```diff
@@ -137,8 +163,14 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* context
   ORT_RETURN_IF_NOT(scales_shape.NumDimensions() == zero_points_shape.NumDimensions(),
                     "scales and zero_points must have the same rank.");
   for (size_t i = 0; i < scales_shape.NumDimensions(); ++i) {
-    ORT_RETURN_IF_NOT(scales_shape[i] == zero_points_shape[i],
-                      "scales and zero_points must have the same shape.");
+    if (components > 1 && i == static_cast<size_t>(p.quantize_axis)) {
+      // For uint8_t with bits=4, zero points are stored as 2 components per byte.
+      ORT_RETURN_IF_NOT((scales_shape[i] + components - 1) / components == zero_points_shape[i],
+                        "scales and zero_points shape does not match.");
+    } else {
+      ORT_RETURN_IF_NOT(scales_shape[i] == zero_points_shape[i],
+                        "scales and zero_points must have the same shape.");
+    }
   }
 }
```

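The packed zero-point check above is a ceiling division along the quantize axis; a small sketch of just that rule (the function name is illustrative):

```cpp
#include <cstdint>

// For uint8 data with bits=4, two zero points share one byte along the
// quantize axis, so the stored dim is ceil(scales_dim / components).
bool ZeroPointDimMatches(int64_t scales_dim, int64_t zp_dim, int64_t components) {
  return (scales_dim + components - 1) / components == zp_dim;
}

int main() {
  // 24 scale blocks -> 12 packed bytes for bits=4, 24 bytes for bits=8.
  return (ZeroPointDimMatches(24, 12, 2) && ZeroPointDimMatches(24, 24, 1)) ? 0 : 1;
}
```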
```diff
@@ -186,21 +218,44 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
   int64_t output_idx = output_idx_base;
   int64_t data_idx = data_idx_base;
   for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
-    auto data_val = GetDataElement(data_ptr, data_idx);
+    int32_t data_val;
+    if constexpr (!std::is_same_v<T1, uint8_t>) {
+      data_val = Get4BitElement(data_ptr, data_idx);
+    } else {  // uint8_t
+      if (bits_ == 4) {
+        data_val = Get4BitElement(data_ptr, data_idx);
+      } else {  // bits_ == 8
+        data_val = static_cast<int32_t>(data_ptr[data_idx]);
+      }
+    }
 
     int64_t x = data_idx / quantize_full_block;
     int64_t y = data_idx % quantize_full_block / quantize_N;
     int64_t z = data_idx % quantize_N;
     int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
     auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
     int32_t zp_val;
+
     if constexpr (std::is_same_v<T1, uint8_t>) {
-      // The default zero point for uint8 weights as stored by MatMulNBits op is 8.
-      zp_val = 8;
+      if (zero_points_ptr) {
+        if (bits_ == 4) {
+          uint8_t packed = zero_points_ptr[scale_idx >> 1];
+          if (scale_idx & 1) {
+            zp_val = static_cast<int32_t>((packed >> 4) & 0x0F);
+          } else {
+            zp_val = static_cast<int32_t>(packed & 0x0F);
+          }
+        } else {  // bits_ == 8
+          zp_val = static_cast<int32_t>(zero_points_ptr[scale_idx]);
+        }
+      } else {
+        const int32_t default_zero_point = bits_ == 4 ? 8 : 128;
+        zp_val = default_zero_point;
+      }
     } else {
-      zp_val = static_cast<int32_t>(zero_points_ptr
-                                        ? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
-                                        : 0);
+      zp_val = zero_points_ptr
+                   ? static_cast<int32_t>(zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1)))
+                   : 0;
     }
 
     output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
@@ -232,7 +287,7 @@ template <typename T1, typename Tind>
 Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
   Prepare p;
   ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
-  auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
+  int64_t components = std::is_same_v<T1, uint8_t> ? (8 / static_cast<int>(bits_)) : 1;
   const auto& data_shape = p.data_tensor->Shape();
   // re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
   // re-shape the indices tensor to [gather_N]
```
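Putting the pieces together, the 8-bit dequantization path reduces to `(q - zero_point) * scale` per block. The sketch below is a simplified standalone illustration that assumes gather_axis 0 and quantize_axis last and skips the kernel's general x/y/z index decomposition; the function name is hypothetical:

```cpp
#include <cstdint>
#include <vector>

// Dequantize one gathered row of bits=8 data:
//   output[k] = (row[k] - zero_point_of_block(k)) * scale_of_block(k)
// Per the kernel above, the default zero point when none is provided is
// 128 for bits=8 (and 8 for bits=4).
std::vector<float> DequantizeRow(const uint8_t* row, const float* scales,
                                 const uint8_t* zero_points,  // may be nullptr
                                 int64_t K, int64_t block_size) {
  std::vector<float> out(static_cast<size_t>(K));
  for (int64_t k = 0; k < K; ++k) {
    const int64_t block = k / block_size;  // index into scales/zero_points
    const int32_t zp = zero_points ? zero_points[block] : 128;
    out[static_cast<size_t>(k)] =
        static_cast<float>(static_cast<int32_t>(row[k]) - zp) * scales[block];
  }
  return out;
}

int main() {
  const uint8_t row[4] = {128, 129, 130, 131};
  const float scales[1] = {0.5f};
  auto out = DequantizeRow(row, scales, /*zero_points=*/nullptr, 4, 4);
  return out[3] == 1.5f ? 0 : 1;  // (131 - 128) * 0.5
}
```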

onnxruntime/contrib_ops/js/quantization/gather_block_quantized.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -28,8 +28,14 @@ class GatherBlockQuantized : public JsKernel {
       block_size = 128;
     }
 
+    int64_t bits;
+    constexpr int64_t default_bits = 4;
+    info.GetAttrOrDefault("bits", &bits, default_bits);
+    ORT_ENFORCE(bits == 4, "GatherBlockQuantized JS kernel only supports bits==4");
+
     ORT_ENFORCE(block_size >= 16 && ((block_size - 1) & block_size) == 0,
                 "'block_size' must be 2's power and not less than 16.");
+
     JSEP_INIT_KERNEL_ATTRIBUTE(GatherBlockQuantized, ({
       "gatherAxis" : $1,
       "quantizeAxis" : $2,
```

onnxruntime/core/graph/contrib_ops/contrib_defs.cc

Lines changed: 44 additions & 23 deletions
```diff
@@ -3599,6 +3599,10 @@ GatherBlockQuantized is a Gather with data quantized.
             "(Optional) block size used for weight quantization. It needs to be a power of 2 and not smaller than 16.",
             AttributeProto::INT,
             static_cast<int64_t>(128))
+      .Attr("bits",
+            "Number of bits used for weight quantization. Must be either 4 or 8.",
+            AttributeProto::INT,
+            static_cast<int64_t>(4))
       .Input(0, "data", "Tensor of rank r >= 1. Block-wise quantized.", "T1")
       .Input(1,
              "indices",
```
```diff
@@ -3614,22 +3618,25 @@ GatherBlockQuantized is a Gather with data quantized.
       .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
         // Type inference
         propagateElemTypeFromInputToOutput(ctx, 2, 0);
-        // Shape inference
+
+        // The first 3 inputs must have shape.
         if (!hasNInputShapes(ctx, 3)) {
           return;
         }
         const TensorShapeProto& data_shape = ctx.getInputType(0)->tensor_type().shape();
         const TensorShapeProto& indices_shape = ctx.getInputType(1)->tensor_type().shape();
         const TensorShapeProto& scales_shape = ctx.getInputType(2)->tensor_type().shape();
-        int r = data_shape.dim_size();
 
-        if (r < 1) {
-          fail_shape_inference("data tensor must have rank >= 1");
+        int r = data_shape.dim_size();
+        if (r <= 1) {
+          fail_shape_inference("data tensor must have rank > 1");
         }
 
         int gather_axis = static_cast<int>(getAttribute(ctx, "gather_axis", 0));
         int quantize_axis = static_cast<int>(getAttribute(ctx, "quantize_axis", 1));
+        int bits = static_cast<int>(getAttribute(ctx, "bits", 4));
         auto block_size = getAttribute(ctx, "block_size", 128);
+
         if (gather_axis < -r || gather_axis >= r) {
           fail_shape_inference("gather_axis must be in [-r, r-1]");
         }
```
```diff
@@ -3643,15 +3650,19 @@ GatherBlockQuantized is a Gather with data quantized.
         gather_axis = (gather_axis + r) % r;
         quantize_axis = (quantize_axis + r) % r;
 
-        if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) {
-          fail_shape_inference("gather_axis must be 0, for uint8 data");
+        if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
+          if (gather_axis != 0) {
+            fail_shape_inference("gather_axis must be 0, for uint8 data");
+          }
+          // The CPU implementation currently requires quantize_axis to be the last dimension.
+          // We relax that in the spec and shape inference since other EPs might not have such a restriction.
         }
 
         if (scales_shape.dim_size() != r) {
           fail_shape_inference("scales must have the same rank as data");
         }
 
-        uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
+        uint32_t components = (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) ? (8 / bits) : 1;
         for (int i = 0; i < r; ++i) {
           if (!data_shape.dim(i).has_dim_value() ||
               !scales_shape.dim(i).has_dim_value() ||
```
```diff
@@ -3663,10 +3674,6 @@ GatherBlockQuantized is a Gather with data quantized.
 
         // validate zero point shape
         if (ctx.hasInput(3)) {
-          if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
-            fail_type_inference("zero_points are not supported for uint8_t data type");
-          }
-
           if (!hasInputShape(ctx, 3)) {
             fail_shape_inference("zero_points shape must be known");
           }
```
```diff
@@ -3679,26 +3686,40 @@ GatherBlockQuantized is a Gather with data quantized.
           for (int i = 0; i < r; ++i) {
             if (!zp_shape.dim(i).has_dim_value() ||
                 zp_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value()) {
+              if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 &&
+                  bits == 4 &&
+                  i == quantize_axis &&
+                  zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + 1) / 2) {
+                continue;
+              }
               fail_shape_inference("zero points shape and scales shape do not match");
             }
           }
         }
 
         int q = indices_shape.dim_size();
         int out_rank = q + r - 1;
-        if (out_rank == 0) {
-          ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
+        auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
+        output_shape->clear_dim();
+        for (int i = 0; i < gather_axis; ++i) {
+          *output_shape->add_dim() = data_shape.dim(i);
+        }
+        for (int i = 0; i < q; ++i) {
+          *output_shape->add_dim() = indices_shape.dim(i);
         }
-        for (int i = 0; i < out_rank; ++i) {
-          // For uint8_t data type the last dimension needs to be expanded back to actual dimension,
-          // because the data 2 int4s are stored packed in a single uint8_t.
-          auto last_dimension_components = (i == out_rank - 1) ? components : 1;
-          *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() =
-              (i < gather_axis)
-                  ? data_shape.dim(i)
-                  : (i >= gather_axis && i < gather_axis + q)
-                        ? indices_shape.dim(i - gather_axis)
-                        : data_shape.dim(i - q + 1) * last_dimension_components;
+        for (int i = gather_axis + 1; i < r; ++i) {
+          *output_shape->add_dim() = data_shape.dim(i);
+        }
+
+        // Find the correct dimension to expand and multiply it by components.
+        if (components > 1) {
+          int quantize_output_dim_idx = (quantize_axis < gather_axis) ? quantize_axis : quantize_axis + q - 1;
+          if (quantize_output_dim_idx < out_rank) {
+            auto* dim_to_update = output_shape->mutable_dim(quantize_output_dim_idx);
+            if (dim_to_update->has_dim_value()) {
+              dim_to_update->set_dim_value(dim_to_update->dim_value() * components);
+            }
+          }
         }
       });
```
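The relaxed shape inference maps a data dimension d to output dimension d when d < gather_axis and to d + q - 1 when d > gather_axis (the q indices dims replace the gather axis), then expands the quantize dimension by `components`. A standalone sketch of that mapping, generalizing the CPU kernel's last-dimension simplification shown earlier; the helper name is hypothetical:

```cpp
#include <cstdint>
#include <vector>

// Mirrors the inference rule above for fully known dims.
std::vector<int64_t> InferOutputShape(const std::vector<int64_t>& data,
                                      const std::vector<int64_t>& indices,
                                      int gather_axis, int quantize_axis,
                                      int64_t components) {
  std::vector<int64_t> out;
  out.insert(out.end(), data.begin(), data.begin() + gather_axis);
  out.insert(out.end(), indices.begin(), indices.end());
  out.insert(out.end(), data.begin() + gather_axis + 1, data.end());
  if (components > 1) {
    const int q = static_cast<int>(indices.size());
    const int idx = (quantize_axis < gather_axis) ? quantize_axis
                                                  : quantize_axis + q - 1;
    out[static_cast<size_t>(idx)] *= components;  // unpack the packed dim
  }
  return out;
}

int main() {
  // data (2000, 1536) int4-packed, indices (4, 8), gather_axis=0,
  // quantize_axis=1 -> output (4, 8, 3072)
  auto s = InferOutputShape({2000, 1536}, {4, 8}, 0, 1, 2);
  return s == std::vector<int64_t>({4, 8, 3072}) ? 0 : 1;
}
```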
37043725
