// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cassert>
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

// Handles GatherND
class GatherNDOpBuilder : public BaseOpBuilder {
 public:
  GatherNDOpBuilder() : BaseOpBuilder("GatherNDOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherNDOpBuilder);

 protected:
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger,
                       std::vector<std::string>& input_names,
                       bool do_op_validation) const override ORT_MUST_USE_RESULT;

  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Normalizes negative indices and narrows int64 indices to int32 for GatherND.
// Returns false if any index is out of bounds after normalization.
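// Example: with data_shape = [2, 3, 4], batch_dims = 0, and an index tuple (-1, 2), the -1 is
// normalized to data_shape[0] - 1 = 1, yielding (1, 2). A tuple that is still out of range
// after normalization, e.g. (1, 3), makes the function return false.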
template <typename SrcType, typename DstType>
bool FixStaticIndicesForGatherND(const std::vector<uint8_t>& onnx_bytes,
                                 const std::vector<int64_t>& indices_shape,
                                 const std::vector<int64_t>& data_shape,
                                 int64_t batch_dims,
                                 std::vector<uint8_t>& qnn_bytes) {
  const int64_t index_tuple_size = indices_shape.back();
  const size_t num_tuples = onnx_bytes.size() / (index_tuple_size * sizeof(SrcType));

  gsl::span<const SrcType> onnx_indices{
      reinterpret_cast<const SrcType*>(onnx_bytes.data()), num_tuples * index_tuple_size};

  qnn_bytes.resize(num_tuples * index_tuple_size * sizeof(DstType));
  gsl::span<DstType> qnn_indices{
      reinterpret_cast<DstType*>(qnn_bytes.data()), num_tuples * index_tuple_size};

  for (size_t i = 0; i < num_tuples; ++i) {
    for (int64_t j = 0; j < index_tuple_size; ++j) {
      SrcType idx = onnx_indices[i * index_tuple_size + j];
      int64_t dim = data_shape[batch_dims + j];

      if (idx < 0) {
        idx += static_cast<SrcType>(dim);
      }

      if (idx < 0 || static_cast<int64_t>(idx) >= dim) {
        return false;  // Out-of-bounds index
      }

      qnn_indices[i * index_tuple_size + j] = static_cast<DstType>(idx);
    }
  }

  return true;
}

Status GatherNDOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                        const NodeUnit& node_unit,
                                        const logging::Logger& logger,
                                        std::vector<std::string>& input_names,
                                        bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();

  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

  const auto& data_input = inputs[0];
  const auto& indices_input = inputs[1];
  const auto& indices_tensor_name = indices_input.node_arg.Name();

  TensorInfo indices_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));

  std::vector<uint8_t> qnn_indices_bytes;

  if (indices_info.is_initializer) {
    std::vector<uint8_t> onnx_indices_bytes;
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*indices_info.initializer_tensor, onnx_indices_bytes));

    std::vector<uint32_t> data_shape;
    ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(data_input.node_arg, data_shape),
                      "Failed to get data shape for GatherND.");

    std::vector<uint32_t> indices_shape;
    ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(indices_input.node_arg, indices_shape),
                      "Failed to get indices shape for GatherND.");

    if (indices_shape.empty()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Indices shape is empty for GatherND.");
    }

    // Get batch_dims for proper index processing
    NodeAttrHelper node_helper(node_unit);
    int64_t batch_dims = node_helper.Get("batch_dims", static_cast<int64_t>(0));

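    // GatherND is given int32 indices on QNN: int64 initializer indices are narrowed to int32
    // here (negative values are normalized at the same time), while dynamic int64 indices are
    // handled further below with a Cast node. Other index types pass through unchanged.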
    if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
      ORT_RETURN_IF_NOT((FixStaticIndicesForGatherND<int64_t, int32_t>(
                            onnx_indices_bytes,
                            std::vector<int64_t>(indices_shape.begin(), indices_shape.end()),
                            std::vector<int64_t>(data_shape.begin(), data_shape.end()),
                            batch_dims,
                            qnn_indices_bytes)),
                        "QNN does not support out-of-bounds indices for GatherND.");
      indices_info.qnn_data_type = QNN_DATATYPE_INT_32;
    } else {
      qnn_indices_bytes = std::move(onnx_indices_bytes);
    }
  }

  Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(indices_tensor_name);
  std::vector<uint32_t> cast_output_shape(indices_info.shape);

  if (!qnn_model_wrapper.IsQnnTensorWrapperExist(indices_tensor_name)) {
    QnnTensorWrapper input_tensorwrapper(indices_tensor_name, tensor_type, indices_info.qnn_data_type,
                                         QnnQuantParamsWrapper(), std::move(indices_info.shape),
                                         std::move(qnn_indices_bytes));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
  }

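  // Dynamic int64 indices cannot be rewritten at build time, so a Cast(int64 -> int32) node is
  // inserted between the indices tensor and the GatherND op.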
  std::string indices_casted_name{indices_tensor_name};
  if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
    assert(!indices_info.is_initializer);
    indices_casted_name += "_int32";
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_casted_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << indices_casted_name;
    } else {
      QnnTensorWrapper indices_cast_tensor(indices_casted_name,
                                           QNN_TENSOR_TYPE_NATIVE,
                                           QNN_DATATYPE_INT_32,
                                           QnnQuantParamsWrapper(),
                                           std::move(cast_output_shape));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(indices_cast_tensor)),
                        "Failed to add gather indices cast tensor.");
      ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_casted_name,
                                                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                        QNN_OP_CAST,
                                                        {indices_tensor_name},
                                                        {indices_casted_name},
                                                        {},
                                                        do_op_validation),
                        "Failed to add GatherND indices cast node.");
    }
  }

  input_names.push_back(indices_casted_name);

  return Status::OK();
}

Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                      const NodeUnit& node_unit,
                                                      std::vector<std::string>&& input_names,
                                                      const logging::Logger& logger,
                                                      bool do_op_validation) const {
  ORT_UNUSED_PARAMETER(logger);
  const auto& output = node_unit.Outputs()[0];
  const std::string& output_name = output.node_arg.Name();

  QnnQuantParamsWrapper quant_params;
  ORT_RETURN_IF_ERROR(quant_params.Init(qnn_model_wrapper, output));

  const auto* type_proto = output.node_arg.TypeAsProto();
  Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
  ORT_RETURN_IF_ERROR(utils::GetQnnDataType(quant_params.IsQuantized(), type_proto, qnn_data_type));

  if (quant_params.IsPerTensor()) {
    // Make sure the output quantization parameters are equal to the input.
    ORT_RETURN_IF_ERROR(SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names,
                                                                 0 /*input_index*/, 0 /*output_index*/, qnn_data_type,
                                                                 quant_params));
  }

  NodeAttrHelper node_helper(node_unit);
  int64_t batch_dims = node_helper.Get("batch_dims", static_cast<int64_t>(0));

  Qnn_Scalar_t batch_dims_scalar = QNN_SCALAR_INIT;
  batch_dims_scalar.dataType = QNN_DATATYPE_UINT_32;
  batch_dims_scalar.uint32Value = static_cast<uint32_t>(batch_dims);

  QnnParamWrapper batch_dims_param(node_unit.Index(), node_unit.Name(),
                                   QNN_OP_GATHER_ND_PARAM_BATCH_DIMS, batch_dims_scalar);
  std::vector<std::string> param_tensor_names = {batch_dims_param.GetParamTensorName()};
  qnn_model_wrapper.AddParamWrapper(std::move(batch_dims_param));

  // Get tensor wrappers for shape calculation
  const auto& data_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]);
  const auto& indices_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[1]);

  // Calculate the QNN output shape for GatherND
  std::vector<uint32_t> qnn_output_shape;
  const auto& data_dims = data_tensor_wrapper.GetTensorDims();
  const auto& indices_dims = indices_tensor_wrapper.GetTensorDims();
  // GatherND output shape:
  //   data.shape[0:batch_dims] ++ indices.shape[batch_dims:-1] ++ data.shape[batch_dims + indices.shape[-1]:]
  // (equivalent to indices.shape[:-1] ++ data.shape[batch_dims + indices.shape[-1]:], since the
  // leading batch_dims dimensions of data and indices must match).
  size_t batch_dims_size = static_cast<size_t>(batch_dims);
  size_t indices_last_dim = indices_dims.back();

  // Add batch dimensions from data
  for (size_t i = 0; i < batch_dims_size && i < data_dims.size(); ++i) {
    qnn_output_shape.push_back(data_dims[i]);
  }

  // Add indices dimensions after the batch dims and excluding the last (index tuple) dimension,
  // so the batch dims are not counted twice
  for (size_t i = batch_dims_size; i < indices_dims.size() - 1; ++i) {
    qnn_output_shape.push_back(indices_dims[i]);
  }

  // Add remaining data dimensions after batch_dims + indices_last_dim
  size_t start_dim = batch_dims_size + indices_last_dim;
  for (size_t i = start_dim; i < data_dims.size(); ++i) {
    qnn_output_shape.push_back(data_dims[i]);
  }
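  // For example, with data_dims = [2, 3, 4], indices_dims = [2, 2], and batch_dims = 0, the
  // computed qnn_output_shape is [2, 4] ([2] from indices.shape[:-1], [4] from data.shape[2:]).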

  std::vector<uint32_t> target_output_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, target_output_shape),
                    "Cannot get target output shape");

  bool reshape_required = (qnn_output_shape.size() != target_output_shape.size());
  bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);

  // If any input was cast down from int64 (tensor name tagged with "_cast_int32") and this node
  // produces a graph output, the output must be cast back to int64.
  bool needs_int64_cast = false;
  if (is_graph_output) {
    for (const auto& input_name : input_names) {
      if (input_name.find("_cast_int32") != std::string::npos) {
        needs_int64_cast = true;
        break;
      }
    }
  }
  struct CastNodeInfo {
    std::string node_name;
    std::string input_name;
    std::string output_name;
  };
  std::vector<CastNodeInfo> cast_node_info_vec;

  // Get the output info for the gather output tensor
  TensorInfo output_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info));

  // If a cast to int64 is needed, add the cast node
  if (needs_int64_cast) {
    std::string cast_node_name = output_name + "_cast_int64";
    std::string cast_input_name = output_name + "_cast_int64_aux";
    std::string cast_output_name = output_name;

    // Create the cast input tensor wrapper - use qnn_output_shape for the intermediate tensor
    QnnTensorWrapper cast_input_tensorwrapper(cast_input_name,
                                              QNN_TENSOR_TYPE_NATIVE,
                                              output_info.qnn_data_type,
                                              output_info.quant_param.Copy(),
                                              std::vector<uint32_t>(qnn_output_shape));

    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor.");
    cast_node_info_vec.push_back({cast_node_name, cast_input_name, cast_output_name});
    Qnn_TensorType_t cast_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
    QnnTensorWrapper cast_output(output_name, cast_tensor_type, qnn_data_type, quant_params.Copy(),
                                 std::vector<uint32_t>(target_output_shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
  }

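  // Route the GatherND output: it may pass through a Reshape (to match the ONNX output rank)
  // and/or the int64 cast above before reaching the final graph output name.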
  std::string gather_output_name = output_name;
  if (reshape_required) {
    gather_output_name += "_ort_qnn_ep_reshape";
  } else if (needs_int64_cast) {
    gather_output_name += "_cast_int64_aux";
  }

  Qnn_TensorType_t tensor_type = (!reshape_required && is_graph_output)
                                     ? QNN_TENSOR_TYPE_APP_READ
                                     : QNN_TENSOR_TYPE_NATIVE;

  QnnTensorWrapper gather_output_tensor(gather_output_name, tensor_type, qnn_data_type,
                                        quant_params.Copy(), std::move(qnn_output_shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_tensor)),
                    "Failed to add GatherND output tensor.");

  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                    QNN_OP_GATHER_ND,
                                                    std::move(input_names),
                                                    {gather_output_name},
                                                    std::move(param_tensor_names),
                                                    do_op_validation),
                    "Failed to create GatherND node.");

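  // If the computed QNN output rank differs from the ONNX output rank, insert a Reshape that
  // restores the expected ONNX output shape.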
  if (reshape_required) {
    Qnn_TensorType_t reshape_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
    QnnTensorWrapper reshape_output(output_name, reshape_tensor_type, qnn_data_type,
                                    std::move(quant_params), std::move(target_output_shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_output)), "Failed to add reshape output.");

    std::string node_output_name = output_name;
    if (needs_int64_cast) {
      // If needs_int64_cast is true, the Reshape must write to the input tensor of the cast node.
      node_output_name = output_name + "_cast_int64_aux";
    }

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(output_name,
                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                      QNN_OP_RESHAPE,
                                                      {gather_output_name},
                                                      {node_output_name},
                                                      {},
                                                      do_op_validation),
                      "Failed to add Reshape node.");
  }

  if (needs_int64_cast) {
    for (const auto& cast_node_info : cast_node_info_vec) {
      // Insert cast node.
      ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name,
                                                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                        QNN_OP_CAST,
                                                        {cast_node_info.input_name},
                                                        {cast_node_info.output_name},
                                                        {}),
                        "Failed to add Cast node");
    }
  }

  return Status::OK();
}

void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<GatherNDOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime