
Commit 5753f8d

[QNN EP] Initial INT4 support (microsoft#21171)

### Description
- Adds support for int4 quantized weights (per-tensor and per-channel) on QNN EP
- Adds a test script that creates an INT4 QDQ model with a Conv
- Adds unit tests demonstrating accuracy issues

### Motivation and Context
This is the next step in being able to run models that use 4-bit quantized weights on QNN EP.

1 parent 1b82d83 commit 5753f8d
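For reference, the arithmetic behind these 4-bit quantized weights: a signed int4 value lies in [-8, 7] and maps back to a real value by the usual affine dequantization. A minimal sketch with illustrative values (not taken from the commit):

```cpp
#include <cassert>

int main() {
  // Affine dequantization: real = scale * (quant - zero_point).
  const float scale = 0.25f;
  const int zero_point = -2;  // int4 zero-point, within [-8, 7]
  const int quant = 6;        // a stored 4-bit weight value
  const float real = scale * (quant - zero_point);
  assert(real == 2.0f);       // 0.25 * (6 - (-2)) = 2.0
  return 0;
}
```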

File tree: 12 files changed, +522 −42 lines

onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc

Lines changed: 2 additions & 1 deletion

```diff
@@ -127,7 +127,8 @@ Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
   int32_t elem_data_type = 0;
   ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(input_1.node_arg, elem_data_type));
 
-  const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) ||
+  const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) ||
+                              (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) ||
                               (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16);
   ORT_RETURN_IF_NOT(is_signed_type, "Conv weights must be of a signed quantized type if quantized per-channel");
 
```

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 35 additions & 4 deletions

```diff
@@ -5,6 +5,8 @@
 #include <cstdlib>
 #include <cstring>
 #include <numeric>
+#include <utility>
+#include <vector>
 
 #include "qnn_model_wrapper.h"
 #include "core/common/safeint.h"
@@ -313,7 +315,8 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector<uint32_t
 }
 
 Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name,
-                                         std::vector<int32_t>& zero_points) const {
+                                         /*out*/ std::vector<int32_t>& zero_points,
+                                         /*out*/ int32_t& onnx_data_type) const {
   const auto& graph_initializers = GetInitializerTensors();
   auto iter = graph_initializers.find(initializer_name);
   ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for zero-point(s): ",
@@ -323,13 +326,14 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name,
   ORT_RETURN_IF_NOT(zp_tensor_proto->has_data_type(), "Expected zero-point initializer ", initializer_name.c_str(),
                     " to have a proto data type.");
 
-  const int32_t onnx_data_type = zp_tensor_proto->data_type();
+  onnx_data_type = zp_tensor_proto->data_type();
   std::vector<uint8_t> initializer_bytes;
 
   ORT_RETURN_IF_ERROR(UnpackInitializerData(*zp_tensor_proto, initializer_bytes));
 
   switch (onnx_data_type) {
     // QNN use -offset for some reason
+    case ONNX_NAMESPACE::TensorProto_DataType_INT4:  // INT4 zero-points are unpacked as 8-bit values for QNN
     case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
       auto int8_span = ReinterpretAsSpan<const int8_t>(gsl::make_span(initializer_bytes));
       std::transform(int8_span.begin(), int8_span.end(), std::back_inserter(zero_points),
@@ -338,6 +342,7 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name,
                      });
       break;
     }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT4:  // UINT4 zero-points are unpacked as 8-bit values for QNN
     case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
       auto uint8_span = ReinterpretAsSpan<const uint8_t>(gsl::make_span(initializer_bytes));
       std::transform(uint8_span.begin(), uint8_span.end(), std::back_inserter(zero_points),
```
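The `// QNN use -offset for some reason` comment above reflects QNN's convention of storing the negated zero-point. A small sketch with hypothetical values (not from the commit) showing why the negation keeps both dequantization formulas in agreement:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // ONNX: real = scale * (quant - zero_point)
  // QNN:  real = scale * (quant + offset), with offset = -zero_point
  const float scale = 0.5f;
  const int8_t zero_point = 3;  // an int4/int8 zero-point widened to int8
  const int32_t offset = -static_cast<int32_t>(zero_point);
  const int8_t quant = 7;
  assert(scale * (quant - zero_point) == scale * (quant + offset));  // both 2.0f
  return 0;
}
```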
```diff
@@ -584,10 +589,36 @@ void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector<std::st
 Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
                                               std::vector<uint8_t>& unpacked_tensor) const {
   if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) {
-    return onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), unpacked_tensor);
+    ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(),
+                                                                  unpacked_tensor));
+  } else {
+    ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor));
+  }
+
+  int32_t onnx_data_type = initializer.data_type();
+
+  // If this is an int4, we need to unpack it because QNN treats int4 as a full int8.
+  if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
+    TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer);
+    const size_t num_elems = shape.Size();
+    std::vector<uint8_t> packed_int4_bytes = std::move(unpacked_tensor);
+    unpacked_tensor = std::vector<uint8_t>(num_elems);
+
+    auto dst = gsl::make_span(reinterpret_cast<int8_t*>(unpacked_tensor.data()), unpacked_tensor.size());
+    auto src = gsl::make_span(reinterpret_cast<const Int4x2*>(packed_int4_bytes.data()), packed_int4_bytes.size());
+    ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor<Int4x2> for QNN");
+  } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
+    TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer);
+    const size_t num_elems = shape.Size();
+    std::vector<uint8_t> packed_int4_bytes = std::move(unpacked_tensor);
+    unpacked_tensor = std::vector<uint8_t>(num_elems);
+
+    auto dst = gsl::make_span(reinterpret_cast<uint8_t*>(unpacked_tensor.data()), unpacked_tensor.size());
+    auto src = gsl::make_span(reinterpret_cast<const UInt4x2*>(packed_int4_bytes.data()), packed_int4_bytes.size());
+    ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor<UInt4x2> for QNN");
   }
 
-  return onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor);
+  return Status::OK();
 }
 
 }  // namespace qnn
```
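A standalone sketch of the widening that `Int4x2::Unpack` performs here, assuming the ONNX packing convention of two 4-bit elements per byte with the first element in the low nibble (the helper below is hypothetical, not ORT's implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Widen packed signed int4 nibbles into one int8 per element, sign-extending each.
std::vector<int8_t> UnpackInt4(const std::vector<uint8_t>& packed, std::size_t num_elems) {
  std::vector<int8_t> out(num_elems);
  for (std::size_t i = 0; i < num_elems; ++i) {
    const uint8_t byte = packed[i / 2];
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    // Sign-extend: nibble values 8..15 represent -8..-1.
    out[i] = static_cast<int8_t>(nibble >= 8 ? static_cast<int>(nibble) - 16 : nibble);
  }
  return out;
}

int main() {
  // 0xF1 packs {1, -1}: low nibble 0x1 = 1, high nibble 0xF = -1.
  const std::vector<uint8_t> packed = {0xF1};
  const auto vals = UnpackInt4(packed, 2);
  assert(vals[0] == 1 && vals[1] == -1);
  return 0;
}
```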

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h

Lines changed: 3 additions & 1 deletion

```diff
@@ -216,7 +216,9 @@ class QnnModelWrapper {
   Status UnpackScales(const std::string& initializer_name, std::vector<float>& scales) const;
 
   // Unpack zero-points from initializer and convert to int32_t (1 zero-point for per-tensor, > 1 for per-channel).
-  Status UnpackZeroPoints(const std::string& initializer_name, std::vector<int32_t>& zero_points) const;
+  Status UnpackZeroPoints(const std::string& initializer_name,
+                          /*out*/ std::vector<int32_t>& zero_points,
+                          /*out*/ int32_t& onnx_data_type) const;
 
   // Checks if a tensor in the ONNX graph is per-channel quantized.
   Status IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def,
```

onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc

Lines changed: 124 additions & 16 deletions

```diff
@@ -9,6 +9,9 @@
 #include "QnnTypes.h"
 #include "core/providers/qnn/builder/qnn_model_wrapper.h"
 
+#define ALIGN_PTR_UP(ptr, align, type) \
+  reinterpret_cast<type>((reinterpret_cast<std::uintptr_t>(ptr) + (align)-1) & ~((align)-1))
+
 namespace onnxruntime {
 namespace qnn {
```
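The macro rounds a pointer up to the next multiple of `align` (a power of two) via the standard `(addr + align - 1) & ~(align - 1)` trick. A quick illustration with plain integers (hypothetical helper, same arithmetic as the macro):

```cpp
#include <cassert>
#include <cstdint>

// Round addr up to the next multiple of align (align must be a power of two).
std::uintptr_t align_up(std::uintptr_t addr, std::uintptr_t align) {
  return (addr + align - 1) & ~(align - 1);
}

int main() {
  assert(align_up(0x1000, 8) == 0x1000);  // already aligned: unchanged
  assert(align_up(0x1001, 8) == 0x1008);  // rounded up to the next 8-byte boundary
  assert(align_up(0x1007, 8) == 0x1008);
  return 0;
}
```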

```diff
@@ -38,9 +41,10 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const {
   return QnnQuantParamsWrapper(*this);
 }
 
+// Initializes by copying from a Qnn_QuantizeParams_t.
 Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
-  if (scale_offset_data_) {
-    scale_offset_data_.reset(nullptr);
+  if (per_channel_data_) {
+    per_channel_data_.reset(nullptr);
     params_ = QNN_QUANTIZE_PARAMS_INIT;
   }
 
@@ -51,6 +55,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
 
   switch (params.quantizationEncoding) {
     case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET:
+    case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET:
       params_ = params;
       break;
     case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: {
@@ -63,27 +68,63 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
       const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets;
 
       if (num_elems > 0) {
-        scale_offset_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_elems);
-        gsl::span<Qnn_ScaleOffset_t> src_span(params.axisScaleOffsetEncoding.scaleOffset, num_elems);
-        std::copy(src_span.begin(), src_span.end(), scale_offset_data_.get());
-        params_.axisScaleOffsetEncoding.scaleOffset = scale_offset_data_.get();
+        const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t);
+        constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t);
+        per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+        Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*);
+
+        std::memcpy(aligned_dst, params.axisScaleOffsetEncoding.scaleOffset, num_bytes);
+        params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst;
       } else {
         params_.axisScaleOffsetEncoding.scaleOffset = nullptr;
       }
       break;
     }
+    case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: {
+      const uint32_t num_elems = params.bwAxisScaleOffsetEncoding.numElements;
+
+      params_.encodingDefinition = params.encodingDefinition;
+      params_.quantizationEncoding = params.quantizationEncoding;
+      params_.bwAxisScaleOffsetEncoding.axis = params.bwAxisScaleOffsetEncoding.axis;
+      params_.bwAxisScaleOffsetEncoding.bitwidth = params.bwAxisScaleOffsetEncoding.bitwidth;
+      params_.bwAxisScaleOffsetEncoding.numElements = num_elems;
+
+      // Deep copy the scales[] and offsets[] arrays
+      if (num_elems > 0) {
+        const size_t num_scale_bytes = num_elems * sizeof(float);
+        const size_t num_zp_bytes = num_elems * sizeof(int32_t);
+        const size_t num_bytes = num_scale_bytes + num_zp_bytes;
+        constexpr std::uintptr_t align = alignof(float);
+        static_assert(alignof(float) == alignof(int32_t));
+
+        per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+        char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*);
+        char* zps_begin = scales_begin + num_scale_bytes;
+
+        std::memcpy(scales_begin, params.bwAxisScaleOffsetEncoding.scales, num_scale_bytes);
+        std::memcpy(zps_begin, params.bwAxisScaleOffsetEncoding.offsets, num_zp_bytes);
+        params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast<float*>(scales_begin);
+        params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast<int32_t*>(zps_begin);
+      } else {
+        params_.bwAxisScaleOffsetEncoding.scales = nullptr;
+        params_.bwAxisScaleOffsetEncoding.offsets = nullptr;
+      }
+      break;
+    }
     default:
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding);
   }
 
   return Status::OK();
 }
 
+// Initialize this object from a (potentially) quantized ONNX tensor.
+// QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers.
 Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) {
   const std::optional<NodeUnitIODef::QuantParam>& ort_quant_params = io_def.quant_param;
 
-  if (scale_offset_data_) {
-    scale_offset_data_.reset(nullptr);
+  if (per_channel_data_) {
+    per_channel_data_.reset(nullptr);
     params_ = QNN_QUANTIZE_PARAMS_INIT;
   }
```
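A simplified sketch of the buffer-carving pattern the `BW_AXIS_SCALE_OFFSET` branch above uses: one over-allocated `char[]` holds the scales array followed by the offsets array, with the start rounded up to alignment (illustrative values; condensed from the commit's logic):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>

int main() {
  const std::size_t num_elems = 3;
  const float scales_src[] = {0.1f, 0.2f, 0.3f};
  const int32_t offsets_src[] = {0, -1, 2};

  const std::size_t num_scale_bytes = num_elems * sizeof(float);
  const std::size_t num_zp_bytes = num_elems * sizeof(int32_t);
  constexpr std::uintptr_t align = alignof(float);

  // Over-allocate by `align` bytes so the start can be rounded up to alignment.
  auto buffer = std::make_unique<char[]>(num_scale_bytes + num_zp_bytes + align);
  char* scales_begin = reinterpret_cast<char*>(
      (reinterpret_cast<std::uintptr_t>(buffer.get()) + align - 1) & ~(align - 1));
  char* zps_begin = scales_begin + num_scale_bytes;  // offsets follow the scales

  std::memcpy(scales_begin, scales_src, num_scale_bytes);
  std::memcpy(zps_begin, offsets_src, num_zp_bytes);

  assert(reinterpret_cast<float*>(scales_begin)[1] == 0.2f);
  assert(reinterpret_cast<int32_t*>(zps_begin)[2] == 2);
  return 0;
}
```

A single allocation keeps both parallel arrays alive for exactly as long as `params_` points at them, which is the reason the member became an opaque `char[]` rather than a typed array.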

```diff
@@ -98,17 +139,25 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
 
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales));
 
+  bool is_int4_type = false;
+
   if (ort_quant_params->zero_point != nullptr) {
-    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points));
+    int32_t onnx_tp_type = 0;
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points,
+                                                           onnx_tp_type));
+
+    is_int4_type = (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) ||
+                   (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4);
   }
 
   const bool is_per_tensor = scales.size() == 1;
 
-  if (is_per_tensor) {
+  // QNN uses different structs to represent quantization parameters depending on
+  // - per-tensor vs per-channel
+  // - int4 vs not int4
+  if (is_per_tensor && !is_int4_type) {
     params_.encodingDefinition = QNN_DEFINITION_DEFINED;
     params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
-
-    // Parse scale & zero_point
     params_.scaleOffsetEncoding.scale = scales[0];
 
     if (ort_quant_params->zero_point != nullptr) {
@@ -117,8 +166,62 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
     } else {
       params_.scaleOffsetEncoding.offset = 0;
     }
-  } else {
-    // Per-channel quantization.
+  } else if (is_per_tensor && is_int4_type) {
+    params_.encodingDefinition = QNN_DEFINITION_DEFINED;
+    params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET;
+    params_.bwScaleOffsetEncoding.bitwidth = 4;
+    params_.bwScaleOffsetEncoding.scale = scales[0];
+
+    if (ort_quant_params->zero_point != nullptr) {
+      ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value");
+      params_.bwScaleOffsetEncoding.offset = zero_points[0];
+    } else {
+      params_.bwScaleOffsetEncoding.offset = 0;
+    }
+  } else if (!is_per_tensor && is_int4_type) {
+    const auto* io_shape = io_def.node_arg.Shape();
+    ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape");
+    const int32_t io_rank = io_shape->dim_size();
+
+    constexpr int64_t DEFAULT_QDQ_AXIS = 1;
+    int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS);
+    if (axis < 0) {
+      axis += io_rank;
+    }
+    ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank,
+                      "Quantization axis must be within the range [0, rank - 1]");
+
+    const size_t num_elems = scales.size();
+    const bool no_zero_points = zero_points.empty();
+    ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value");
+    ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems,
+                      "Expected the same number of zero-points and scales for per-channel quantization");
+
+    params_.encodingDefinition = QNN_DEFINITION_DEFINED;
+    params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
+    params_.bwAxisScaleOffsetEncoding.axis = static_cast<int32_t>(*(ort_quant_params->axis));
+    params_.bwAxisScaleOffsetEncoding.bitwidth = 4;
+    params_.bwAxisScaleOffsetEncoding.numElements = static_cast<uint32_t>(num_elems);
+
+    const size_t num_scale_bytes = num_elems * sizeof(float);
+    const size_t num_zp_bytes = num_elems * sizeof(int32_t);
+    const size_t num_bytes = num_scale_bytes + num_zp_bytes;
+    constexpr std::uintptr_t align = alignof(float);
+    per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+
+    char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*);
+    char* zps_begin = scales_begin + num_scale_bytes;
+    gsl::span<float> scales_span(reinterpret_cast<float*>(scales_begin), num_elems);
+    gsl::span<int32_t> zps_span(reinterpret_cast<int32_t*>(zps_begin), num_elems);
+
+    for (size_t i = 0; i < num_elems; i++) {
+      scales_span[i] = scales[i];
+      zps_span[i] = no_zero_points ? 0 : zero_points[i];
+    }
+
+    params_.bwAxisScaleOffsetEncoding.scales = scales_span.data();
+    params_.bwAxisScaleOffsetEncoding.offsets = zps_span.data();
+  } else if (!is_per_tensor && !is_int4_type) {
     const auto* io_shape = io_def.node_arg.Shape();
     ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape");
     const int32_t io_rank = io_shape->dim_size();
@@ -140,8 +243,11 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
     ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems,
                       "Expected the same number of zero-points and scales for per-channel quantization");
 
-    scale_offset_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_elems);
-    gsl::span<Qnn_ScaleOffset_t> data_span(scale_offset_data_.get(), num_elems);
+    const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t);
+    constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t);
+    per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+    Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*);
+    gsl::span<Qnn_ScaleOffset_t> data_span(aligned_dst, num_elems);
 
     for (size_t i = 0; i < num_elems; i++) {
       data_span[i].scale = scales[i];
@@ -151,6 +257,8 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
     params_.axisScaleOffsetEncoding.axis = static_cast<int32_t>(axis);
     params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast<uint32_t>(num_elems);
     params_.axisScaleOffsetEncoding.scaleOffset = data_span.data();
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected tensor kind for QuantParamsWrapper::Init()");
  }
 
   return Status::OK();
```
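Distilling the four branches of `Init` above, the encoding choice depends only on two booleans. A summary sketch (the enum and helper are illustrative, not part of the commit):

```cpp
#include <cassert>

enum class Encoding { ScaleOffset, BwScaleOffset, AxisScaleOffset, BwAxisScaleOffset };

Encoding ChooseEncoding(bool is_per_tensor, bool is_int4_type) {
  if (is_per_tensor) {
    // Per-tensor: single scale/offset; int4 needs the bitwidth ("bw") variant.
    return is_int4_type ? Encoding::BwScaleOffset : Encoding::ScaleOffset;
  }
  // Per-channel: int4 uses parallel scale/offset arrays, others use pair structs.
  return is_int4_type ? Encoding::BwAxisScaleOffset : Encoding::AxisScaleOffset;
}

int main() {
  assert(ChooseEncoding(true, true) == Encoding::BwScaleOffset);
  assert(ChooseEncoding(false, false) == Encoding::AxisScaleOffset);
  return 0;
}
```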

onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h

Lines changed: 11 additions & 5 deletions

```diff
@@ -48,17 +48,17 @@ class QnnQuantParamsWrapper {
            (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET));
   }
 
-  bool IsPerChannel(bool include_bw = false) const {
+  bool IsPerChannel() const {
     return params_.encodingDefinition == QNN_DEFINITION_DEFINED &&
            (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET ||
-            (include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET));
+            (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET));
   }
 
   // Handle transposing of a per-channel quantized tensor. The quantization parameter's axis
   // must be transposed using the inverse permutation of the Transpose.
   template <typename IntType>
   Status HandleTranspose(gsl::span<const IntType> perm) {
-    if (!IsPerChannel(true)) {
+    if (!IsPerChannel()) {
       return Status::OK();
     }
 
@@ -82,7 +82,7 @@ class QnnQuantParamsWrapper {
   template <typename IntType>
   Status HandleUnsqueeze(gsl::span<const IntType> orig_shape,
                          gsl::span<const IntType> new_shape) {
-    if (!IsPerChannel(true)) {
+    if (!IsPerChannel()) {
       return Status::OK();
     }
 
@@ -134,7 +134,13 @@ class QnnQuantParamsWrapper {
 
  private:
   Qnn_QuantizeParams_t params_;
-  std::unique_ptr<Qnn_ScaleOffset_t[]> scale_offset_data_;  // Stores per-channel scales and offsets
+
+  // Stores arrays of per-channel scales and offsets. Fields in params_ point to this data.
+  //
+  // Use an opaque array of bytes because QNN uses different data layouts depending on the quantization encoding:
+  //   - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: array of scale/zp pairs [{scale0, zp0}, {scale1, zp1}, ...]
+  //   - QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: parallel arrays for scales and zps [scale0, ...] [zp0, zp1, ...]
+  std::unique_ptr<char[]> per_channel_data_;
 };
 
 }  // namespace qnn
```
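The `HandleTranspose` comment above says the quantization axis must be mapped through the inverse permutation. A sketch of that remapping (hypothetical helper, not the commit's code): output dimension `j` of a Transpose takes input dimension `perm[j]`, so the old axis lands at the index `j` with `perm[j] == old_axis`, which is exactly the inverse permutation applied to `old_axis`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns the position of old_axis after applying a Transpose with permutation perm.
std::size_t TransposedQuantAxis(const std::vector<std::size_t>& perm, std::size_t old_axis) {
  for (std::size_t j = 0; j < perm.size(); ++j) {
    if (perm[j] == old_axis) return j;
  }
  return old_axis;  // not found: leave unchanged (caller should validate perm)
}

int main() {
  // NCHW -> NHWC transpose: perm = {0, 2, 3, 1}. A channel axis of 1 moves to 3.
  assert(TransposedQuantAxis({0, 2, 3, 1}, 1) == 3);
  return 0;
}
```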
