@@ -18,17 +18,18 @@ namespace contrib {
 
 namespace {
 template <typename T1>
-int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
+int32_t Get4BitElement(const T1* data_ptr, int64_t data_idx) {
   return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
 }
 
 template <>
-int32_t GetDataElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
+int32_t Get4BitElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
   const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
   // Weights are stored as (nibble2)(nibble1) in uint8_t.
   auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
   return data_val;
 }
+
 } // namespace
 
 template <typename T1, typename Tind>
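To make the packed layout concrete, here is a minimal standalone sketch of the nibble unpacking that the `uint8_t` specialization of `Get4BitElement` performs (the helper name `UnpackNibble` and the sample byte are illustrative, not from the PR):

```cpp
#include <cstdint>
#include <cstdio>

// Two 4-bit elements per byte, low nibble first: even indices read the low
// nibble, odd indices the high nibble, mirroring Get4BitElement<uint8_t>.
static int32_t UnpackNibble(const uint8_t* data, int64_t idx) {
  const uint8_t b = data[idx >> 1];
  return (idx & 1) ? ((b >> 4) & 0x0F) : (b & 0x0F);
}

int main() {
  const uint8_t packed[] = {0x5A};  // element 0 = 0xA, element 1 = 0x5
  std::printf("%d %d\n", UnpackNibble(packed, 0), UnpackNibble(packed, 1));  // 10 5
  return 0;
}
```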
@@ -47,6 +48,10 @@ class GatherBlockQuantized : public OpKernel {
       block_size_ = 128;
     }
 
+    constexpr int64_t default_bits = 4;
+    info.GetAttrOrDefault("bits", &bits_, default_bits);
+    ORT_ENFORCE(bits_ == 4 || bits_ == 8, "GatherBlockQuantized only supports bits == 4 or bits == 8.");
+
     ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
                 "'block_size' must be 2's power and not less than 16.");
   }
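The `((block_size_ - 1) & block_size_) == 0` guard is the standard bit trick for power-of-two tests: a positive integer shares no set bits with its predecessor exactly when it has a single set bit. A quick sketch with illustrative values:

```cpp
#include <cstdint>
#include <cstdio>

static bool IsValidBlockSize(int64_t v) {
  // Power of two (single set bit) and at least 16, as the ORT_ENFORCE requires.
  return v >= 16 && ((v - 1) & v) == 0;
}

int main() {
  for (int64_t v : {8, 16, 24, 128}) {
    std::printf("%lld -> %s\n", static_cast<long long>(v),
                IsValidBlockSize(v) ? "accepted" : "rejected");
  }
  // 8 -> rejected (too small), 16 -> accepted, 24 -> rejected, 128 -> accepted
  return 0;
}
```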
@@ -84,6 +89,7 @@ class GatherBlockQuantized : public OpKernel {
   int64_t gather_axis_;
   int64_t quantize_axis_;
   int64_t block_size_;
+  int64_t bits_;
 };
 
 template <typename T1, typename Tind>
@@ -94,13 +100,21 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
   p.zero_points_tensor = context->Input<Tensor>(3);
 
   const auto& data_shape = p.data_tensor->Shape();
-  const auto& indices_shape = p.indices_tensor->Shape();
   const auto data_rank = data_shape.NumDimensions();
   p.gather_axis = HandleNegativeAxis(gather_axis_, narrow<int64_t>(data_rank));
+
   p.quantize_axis = HandleNegativeAxis(quantize_axis_, narrow<int64_t>(data_rank));
+  if constexpr (std::is_same_v<T1, uint8_t>) {
+    ORT_RETURN_IF_NOT(p.gather_axis == 0, "For uint8_t data, gather_axis must be 0.");
+    ORT_RETURN_IF_NOT(p.quantize_axis == static_cast<int64_t>(data_rank) - 1, "For uint8_t data, quantize_axis must be the last dimension.");
+    ORT_RETURN_IF_NOT(p.gather_axis != p.quantize_axis, "gather_axis and quantize_axis must not be the same.");
+  }
+
+  const auto& indices_shape = p.indices_tensor->Shape();
+  const auto indices_rank = indices_shape.NumDimensions();
 
   std::vector<int64_t> shape;
-  shape.reserve(data_rank - 1 + indices_shape.NumDimensions());
+  shape.reserve(data_rank - 1 + indices_rank);
 
   // get output tensor
   // replace the dimension for p.gather_axis with the shape from the indices
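`HandleNegativeAxis` is ONNX Runtime's helper for Python-style negative axes; a simplified stand-in (not the library implementation) to show what the two calls above compute:

```cpp
#include <cstdint>
#include <cstdio>

// Map an axis in [-rank, rank - 1] to [0, rank - 1].
static int64_t NormalizeAxis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis;  // range checking omitted for brevity
}

int main() {
  // For a rank-2 uint8_t weight, quantize_axis = -1 normalizes to 1,
  // the last dimension, which is the only value the uint8_t path accepts.
  std::printf("%lld\n", static_cast<long long>(NormalizeAxis(-1, 2)));  // 1
  return 0;
}
```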
@@ -113,12 +127,21 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
   for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
     shape.push_back(data_shape[narrow<size_t>(i)]);
 
-  // When data is stored as uint8_t, each element has two int4 values.
+  // When bits==4 and data is stored as uint8_t, each element has two int4 values.
   // The shape in the onnx model reflects that by having the last dimension be half the number of values.
-  // Ex: For a true data size of 2000x3072, the onnx model would have data of shape 2000x1536.
+  // Example: For a true data size of 2000x3072, the packed uint8 tensor has shape 2000x1536.
   // However the outputs still need to be of size 2000x3072. Therefore we x2 the last dimension here.
-  uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
-  shape[shape.size() - 1] = shape.back() * components;
+  uint32_t components = 1;
+  if constexpr (std::is_same_v<T1, uint8_t>) {
+    components = 8 / static_cast<int>(bits_);
+    if (components > 1) {
+      // To handle a quantize_axis that is not the last dimension, this would be:
+      //   shape[(p.quantize_axis < p.gather_axis) ? p.quantize_axis : p.quantize_axis + indices_rank - 1] *= components;
+      // Since quantize_axis is constrained to be the last dimension, it simplifies to:
+      shape.back() *= components;
+    }
+  }
+
   p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
 
   // validate quantization parameters
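A worked example of the output-shape logic, reusing the 2000x3072 case from the comment above (the indices shape and all values are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Packed uint8 data of true size 2000x3072 with bits = 4: stored as 2000x1536.
  const std::vector<int64_t> data_shape = {2000, 1536};
  const std::vector<int64_t> indices_shape = {5};
  const size_t gather_axis = 0;
  const int64_t components = 8 / 4;  // 8 / bits

  // Replace the gather_axis dimension with the indices shape, keep the rest,
  // then unpack the last (quantize) axis.
  std::vector<int64_t> out;
  for (size_t i = 0; i < gather_axis; ++i) out.push_back(data_shape[i]);
  for (int64_t d : indices_shape) out.push_back(d);
  for (size_t i = gather_axis + 1; i < data_shape.size(); ++i) out.push_back(data_shape[i]);
  out.back() *= components;  // 1536 -> 3072

  for (int64_t d : out) std::printf("%lld ", static_cast<long long>(d));  // 5 3072
  std::printf("\n");
  return 0;
}
```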
@@ -137,8 +160,14 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
     ORT_RETURN_IF_NOT(scales_shape.NumDimensions() == zero_points_shape.NumDimensions(),
                       "scales and zero_points must have the same rank.");
     for (size_t i = 0; i < scales_shape.NumDimensions(); ++i) {
-      ORT_RETURN_IF_NOT(scales_shape[i] == zero_points_shape[i],
-                        "scales and zero_points must have the same shape.");
+      if (components > 1 && i == static_cast<size_t>(p.quantize_axis)) {
+        // For uint8_t with bits == 4, zero points are stored packed, two per byte.
+        ORT_RETURN_IF_NOT((scales_shape[i] + components - 1) / components == zero_points_shape[i],
+                          "scales and zero_points shapes do not match.");
+      } else {
+        ORT_RETURN_IF_NOT(scales_shape[i] == zero_points_shape[i],
+                          "scales and zero_points must have the same shape.");
+      }
     }
   }
 
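The relaxed check on the quantize axis is a ceiling division: with bits = 4 the zero points are packed two per byte, so a scales dimension of k corresponds to ceil(k / 2) stored bytes. A small numeric sketch (dimensions are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t components = 2;  // 8 / bits for bits == 4
  for (int64_t scales_dim : {8, 9}) {
    // Same ceiling division as the ORT_RETURN_IF_NOT above.
    const int64_t zp_dim = (scales_dim + components - 1) / components;
    std::printf("scales dim %lld -> packed zero_points dim %lld\n",
                static_cast<long long>(scales_dim), static_cast<long long>(zp_dim));
  }
  // scales dim 8 -> packed zero_points dim 4
  // scales dim 9 -> packed zero_points dim 5
  return 0;
}
```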
@@ -186,21 +215,44 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
   int64_t output_idx = output_idx_base;
   int64_t data_idx = data_idx_base;
   for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
-    auto data_val = GetDataElement(data_ptr, data_idx);
+    int32_t data_val;
+    if constexpr (!std::is_same_v<T1, uint8_t>) {
+      data_val = Get4BitElement(data_ptr, data_idx);
+    } else {  // uint8_t
+      if (bits_ == 4) {
+        data_val = Get4BitElement(data_ptr, data_idx);
+      } else {  // bits_ == 8
+        data_val = static_cast<int32_t>(data_ptr[data_idx]);
+      }
+    }
 
     int64_t x = data_idx / quantize_full_block;
     int64_t y = data_idx % quantize_full_block / quantize_N;
     int64_t z = data_idx % quantize_N;
     int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
     auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
     int32_t zp_val;
+
     if constexpr (std::is_same_v<T1, uint8_t>) {
-      // The default zero point for uint8 weights as stored by MatMulNBits op is 8.
-      zp_val = 8;
+      if (zero_points_ptr) {
+        if (bits_ == 4) {
+          uint8_t packed = zero_points_ptr[scale_idx >> 1];
+          if (scale_idx & 1) {
+            zp_val = static_cast<int32_t>((packed >> 4) & 0x0F);
+          } else {
+            zp_val = static_cast<int32_t>(packed & 0x0F);
+          }
+        } else {  // bits_ == 8
+          zp_val = static_cast<int32_t>(zero_points_ptr[scale_idx]);
+        }
+      } else {
+        const int32_t default_zero_point = bits_ == 4 ? 8 : 128;
+        zp_val = default_zero_point;
+      }
     } else {
-      zp_val = static_cast<int32_t>(zero_points_ptr
-                                        ? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
-                                        : 0);
+      zp_val = zero_points_ptr
+                   ? static_cast<int32_t>(zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1)))
+                   : 0;
     }
 
     output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
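The final line is the usual affine dequantization, value = (q - zero_point) * scale, with the default zero point at the midpoint of the quantized range (8 for 4-bit, 128 for 8-bit). A worked check with made-up numbers:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t q = 11;        // 4-bit quantized value in [0, 15]
  const int32_t zp = 8;        // default zero point for bits == 4
  const float scale = 0.5f;
  const float value = static_cast<float>(q - zp) * scale;
  std::printf("%f\n", value);  // 1.500000
  return 0;
}
```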
@@ -232,7 +284,7 @@ template <typename T1, typename Tind>
 Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
   Prepare p;
   ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
-  auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
+  int64_t components = std::is_same_v<T1, uint8_t> ? (8 / static_cast<int>(bits_)) : 1;
   const auto& data_shape = p.data_tensor->Shape();
   // re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
   // re-shape the indices tensor to [gather_N]
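For context on the comment above, the `[gather_M, gather_axis_dim, gather_block]` view collapses all dimensions before the gather axis into `gather_M` and all dimensions after it into `gather_block`. A sketch of that factorization with an illustrative shape:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // View a [2, 3, 4, 5] tensor gathered on axis 1 as [M, axis_dim, block].
  const std::vector<int64_t> dims = {2, 3, 4, 5};
  const size_t gather_axis = 1;
  int64_t M = 1, block = 1;
  for (size_t i = 0; i < gather_axis; ++i) M *= dims[i];
  for (size_t i = gather_axis + 1; i < dims.size(); ++i) block *= dims[i];
  std::printf("M=%lld axis_dim=%lld block=%lld\n", static_cast<long long>(M),
              static_cast<long long>(dims[gather_axis]),
              static_cast<long long>(block));  // M=2 axis_dim=3 block=20
  return 0;
}
```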