 #include "QnnTypes.h"
 #include "core/providers/qnn/builder/qnn_model_wrapper.h"
 
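+// Rounds `ptr` up to the next multiple of `align` (a power of two) and casts the result to `type`.
+// For example, a pointer value of 0x1003 with align = 8 rounds up to 0x1008.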
+#define ALIGN_PTR_UP(ptr, align, type) \
+  reinterpret_cast<type>((reinterpret_cast<std::uintptr_t>(ptr) + (align)-1) & ~((align)-1))
+
 namespace onnxruntime {
 namespace qnn {
 
@@ -38,9 +41,10 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const {
   return QnnQuantParamsWrapper(*this);
 }
 
+// Initializes by copying from a Qnn_QuantizeParams_t.
 Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
-  if (scale_offset_data_) {
-    scale_offset_data_.reset(nullptr);
+  if (per_channel_data_) {
+    per_channel_data_.reset(nullptr);
     params_ = QNN_QUANTIZE_PARAMS_INIT;
   }
 
@@ -51,6 +55,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
 
   switch (params.quantizationEncoding) {
     case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET:
+    case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET:
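+      // These per-tensor encodings contain no pointers, so copying the struct is sufficient.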
       params_ = params;
       break;
     case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: {
@@ -63,27 +68,63 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
       const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets;
 
       if (num_elems > 0) {
-        scale_offset_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_elems);
-        gsl::span<Qnn_ScaleOffset_t> src_span(params.axisScaleOffsetEncoding.scaleOffset, num_elems);
-        std::copy(src_span.begin(), src_span.end(), scale_offset_data_.get());
-        params_.axisScaleOffsetEncoding.scaleOffset = scale_offset_data_.get();
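+        // Deep copy the scale/offset array into per_channel_data_. The raw char buffer is
+        // over-allocated by `align` bytes so an aligned Qnn_ScaleOffset_t array fits inside it.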
+        const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t);
+        constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t);
+        per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+        Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*);
+
+        std::memcpy(aligned_dst, params.axisScaleOffsetEncoding.scaleOffset, num_bytes);
+        params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst;
       } else {
         params_.axisScaleOffsetEncoding.scaleOffset = nullptr;
       }
       break;
     }
+    case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: {
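+      // Bitwidth-aware per-channel encoding: the scales[] and offsets[] arrays are stored
+      // separately and must be deep copied.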
+      const uint32_t num_elems = params.bwAxisScaleOffsetEncoding.numElements;
+
+      params_.encodingDefinition = params.encodingDefinition;
+      params_.quantizationEncoding = params.quantizationEncoding;
+      params_.bwAxisScaleOffsetEncoding.axis = params.bwAxisScaleOffsetEncoding.axis;
+      params_.bwAxisScaleOffsetEncoding.bitwidth = params.bwAxisScaleOffsetEncoding.bitwidth;
+      params_.bwAxisScaleOffsetEncoding.numElements = num_elems;
+
+      // Deep copy the scales[] and offsets[] arrays.
+      if (num_elems > 0) {
+        const size_t num_scale_bytes = num_elems * sizeof(float);
+        const size_t num_zp_bytes = num_elems * sizeof(int32_t);
+        const size_t num_bytes = num_scale_bytes + num_zp_bytes;
+        constexpr std::uintptr_t align = alignof(float);
+        static_assert(alignof(float) == alignof(int32_t));
+
+        per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+        char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*);
+        char* zps_begin = scales_begin + num_scale_bytes;
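+        // Buffer layout: [alignment padding][num_elems floats (scales)][num_elems int32_t (offsets)].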
+
+        std::memcpy(scales_begin, params.bwAxisScaleOffsetEncoding.scales, num_scale_bytes);
+        std::memcpy(zps_begin, params.bwAxisScaleOffsetEncoding.offsets, num_zp_bytes);
+        params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast<float*>(scales_begin);
+        params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast<int32_t*>(zps_begin);
+      } else {
+        params_.bwAxisScaleOffsetEncoding.scales = nullptr;
+        params_.bwAxisScaleOffsetEncoding.offsets = nullptr;
+      }
+      break;
+    }
     default:
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding);
   }
 
   return Status::OK();
 }
 
+// Initialize this object from a (potentially) quantized ONNX tensor.
+// QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers.
 Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) {
   const std::optional<NodeUnitIODef::QuantParam>& ort_quant_params = io_def.quant_param;
 
-  if (scale_offset_data_) {
-    scale_offset_data_.reset(nullptr);
+  if (per_channel_data_) {
+    per_channel_data_.reset(nullptr);
     params_ = QNN_QUANTIZE_PARAMS_INIT;
   }
 
@@ -98,17 +139,25 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
 
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales));
 
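+  // INT4/UINT4 tensors are mapped to QNN's bitwidth ("BW") encodings below, so detect them
+  // via the zero-point initializer's ONNX element type.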
+  bool is_int4_type = false;
+
   if (ort_quant_params->zero_point != nullptr) {
-    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points));
+    int32_t onnx_tp_type = 0;
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points,
+                                                           onnx_tp_type));
+
+    is_int4_type = (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) ||
+                   (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4);
   }
 
   const bool is_per_tensor = scales.size() == 1;
 
-  if (is_per_tensor) {
+  // QNN uses different structs to represent quantization parameters depending on
+  // - per-tensor vs per-channel
+  // - int4 vs not int4
+  if (is_per_tensor && !is_int4_type) {
     params_.encodingDefinition = QNN_DEFINITION_DEFINED;
     params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
-
-    // Parse scale & zero_point
     params_.scaleOffsetEncoding.scale = scales[0];
 
     if (ort_quant_params->zero_point != nullptr) {
@@ -117,8 +166,62 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
     } else {
       params_.scaleOffsetEncoding.offset = 0;
     }
-  } else {
-    // Per-channel quantization.
+  } else if (is_per_tensor && is_int4_type) {
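+    // Per-tensor int4: use the bitwidth-aware scale/offset encoding with an explicit bitwidth of 4.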
+    params_.encodingDefinition = QNN_DEFINITION_DEFINED;
+    params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET;
+    params_.bwScaleOffsetEncoding.bitwidth = 4;
+    params_.bwScaleOffsetEncoding.scale = scales[0];
+
+    if (ort_quant_params->zero_point != nullptr) {
+      ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value");
+      params_.bwScaleOffsetEncoding.offset = zero_points[0];
+    } else {
+      params_.bwScaleOffsetEncoding.offset = 0;
+    }
+  } else if (!is_per_tensor && is_int4_type) {
+    const auto* io_shape = io_def.node_arg.Shape();
+    ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape");
+    const int32_t io_rank = io_shape->dim_size();
+
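+    // ONNX QuantizeLinear/DequantizeLinear default to axis 1 when the axis attribute is not specified.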
+    constexpr int64_t DEFAULT_QDQ_AXIS = 1;
+    int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS);
+    if (axis < 0) {
+      axis += io_rank;
+    }
+    ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank,
+                      "Quantization axis must be within the range [0, rank - 1]");
+
+    const size_t num_elems = scales.size();
+    const bool no_zero_points = zero_points.empty();
+    ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value");
+    ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems,
+                      "Expected the same number of zero-points and scales for per-channel quantization");
+
+    params_.encodingDefinition = QNN_DEFINITION_DEFINED;
+    params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
+    params_.bwAxisScaleOffsetEncoding.axis = static_cast<int32_t>(axis);
+    params_.bwAxisScaleOffsetEncoding.bitwidth = 4;
+    params_.bwAxisScaleOffsetEncoding.numElements = static_cast<uint32_t>(num_elems);
+
+    const size_t num_scale_bytes = num_elems * sizeof(float);
+    const size_t num_zp_bytes = num_elems * sizeof(int32_t);
+    const size_t num_bytes = num_scale_bytes + num_zp_bytes;
+    constexpr std::uintptr_t align = alignof(float);
+    per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+
+    char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*);
+    char* zps_begin = scales_begin + num_scale_bytes;
+    gsl::span<float> scales_span(reinterpret_cast<float*>(scales_begin), num_elems);
+    gsl::span<int32_t> zps_span(reinterpret_cast<int32_t*>(zps_begin), num_elems);
+
+    for (size_t i = 0; i < num_elems; i++) {
+      scales_span[i] = scales[i];
+      zps_span[i] = no_zero_points ? 0 : zero_points[i];
+    }
+
+    params_.bwAxisScaleOffsetEncoding.scales = scales_span.data();
+    params_.bwAxisScaleOffsetEncoding.offsets = zps_span.data();
+  } else if (!is_per_tensor && !is_int4_type) {
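+    // Per-channel, non-int4: standard axis scale/offset encoding backed by an array of Qnn_ScaleOffset_t.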
     const auto* io_shape = io_def.node_arg.Shape();
     ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape");
     const int32_t io_rank = io_shape->dim_size();
@@ -140,8 +243,11 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
     ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems,
                       "Expected the same number of zero-points and scales for per-channel quantization");
 
-    scale_offset_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_elems);
-    gsl::span<Qnn_ScaleOffset_t> data_span(scale_offset_data_.get(), num_elems);
+    const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t);
+    constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t);
+    per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
+    Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*);
+    gsl::span<Qnn_ScaleOffset_t> data_span(aligned_dst, num_elems);
 
     for (size_t i = 0; i < num_elems; i++) {
       data_span[i].scale = scales[i];
@@ -151,6 +257,8 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
     params_.axisScaleOffsetEncoding.axis = static_cast<int32_t>(axis);
     params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast<uint32_t>(num_elems);
     params_.axisScaleOffsetEncoding.scaleOffset = data_span.data();
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected tensor kind for QuantParamsWrapper::Init()");
   }
 
   return Status::OK();