Skip to content

Commit 539d0ed

Browse files
[QNN EP] Add support for GatherNd Op in QNN EP (microsoft#25635)
- Added a new op builder for the GatherND op - Added unit tests for GatherND, including QDQ tests - Disabled two tests in the ORT test suite because QNN CPU does not support negative indices ### Description Adds support for the GatherND op in QNN EP. ### Motivation and Context Currently the GatherND op is not supported in QNN EP and hence falls back to the ORT CPU provider.
1 parent 59871e3 commit 539d0ed

File tree

5 files changed

+459
-0
lines changed

5 files changed

+459
-0
lines changed

onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
199199
{
200200
CreateCumSumOpBuilder("CumSum", *this);
201201
}
202+
203+
{
204+
CreateGatherNDOpBuilder("GatherND", *this);
205+
}
202206
}
203207

204208
const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,7 @@ void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
113113

114114
void CreateMeanOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
115115

116+
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
117+
116118
} // namespace qnn
117119
} // namespace onnxruntime
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <cassert>
5+
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
6+
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
7+
#include "core/providers/qnn/builder/op_builder_factory.h"
8+
#include "core/providers/qnn/builder/qnn_utils.h"
9+
10+
namespace onnxruntime {
11+
namespace qnn {
12+
13+
// Handles GatherND.
// Translates an ONNX GatherND node (data, indices, batch_dims attribute) into the
// QNN GatherNd op, fixing up int64/negative indices which QNN does not support.
class GatherNDOpBuilder : public BaseOpBuilder {
 public:
  GatherNDOpBuilder() : BaseOpBuilder("GatherNDOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherNDOpBuilder);

 protected:
  // Adds the data and indices input tensors to the QNN model. Static int64 indices
  // are converted to int32 up front; dynamic int64 indices get a Cast node.
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger,
                       std::vector<std::string>& input_names,
                       bool do_op_validation) const override ORT_MUST_USE_RESULT;

  // Adds the batch_dims parameter, the GatherNd node itself, and any trailing
  // Reshape/Cast nodes needed to match the ONNX output shape and type.
  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};
// Fixes negative indices and converts int64 to uint32 for GatherND
34+
template <typename SrcType, typename DstType>
35+
bool FixStaticIndicesForGatherND(const std::vector<uint8_t>& onnx_bytes,
36+
const std::vector<int64_t>& indices_shape,
37+
const std::vector<int64_t>& data_shape,
38+
int64_t batch_dims,
39+
std::vector<uint8_t>& qnn_bytes) {
40+
const int64_t index_tuple_size = indices_shape.back();
41+
const size_t num_tuples = onnx_bytes.size() / (index_tuple_size * sizeof(SrcType));
42+
43+
gsl::span<const SrcType> onnx_indices{
44+
reinterpret_cast<const SrcType*>(onnx_bytes.data()), num_tuples * index_tuple_size};
45+
46+
qnn_bytes.resize(num_tuples * index_tuple_size * sizeof(DstType));
47+
gsl::span<DstType> qnn_indices{
48+
reinterpret_cast<DstType*>(qnn_bytes.data()), num_tuples * index_tuple_size};
49+
50+
for (size_t i = 0; i < num_tuples; ++i) {
51+
for (int64_t j = 0; j < index_tuple_size; ++j) {
52+
SrcType idx = onnx_indices[i * index_tuple_size + j];
53+
int64_t dim = data_shape[batch_dims + j];
54+
55+
if (idx < 0) {
56+
idx += static_cast<SrcType>(dim);
57+
}
58+
59+
if (idx < 0 || static_cast<int64_t>(idx) >= dim) {
60+
return false; // Out-of-bounds index
61+
}
62+
63+
qnn_indices[i * index_tuple_size + j] = static_cast<DstType>(idx);
64+
}
65+
}
66+
67+
return true;
68+
}
69+
70+
// Adds the GatherND inputs to the QNN model.
// Static int64 indices are narrowed to int32 and negative values are fixed up at build
// time; static int32 indices also get the negative-index fix-up (QNN supports neither
// negative nor 64-bit indices). Dynamic int64 indices are handled with a Cast node.
Status GatherNDOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                        const NodeUnit& node_unit,
                                        const logging::Logger& logger,
                                        std::vector<std::string>& input_names,
                                        bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();

  // Input 0 (data) needs no special handling.
  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

  const auto& data_input = inputs[0];
  const auto& indices_input = inputs[1];
  const auto& indices_tensor_name = indices_input.node_arg.Name();

  TensorInfo indices_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));

  std::vector<uint8_t> qnn_indices_bytes;

  // Static (initializer) indices can be validated and rewritten at build time.
  if (indices_info.is_initializer) {
    std::vector<uint8_t> onnx_indices_bytes;
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*indices_info.initializer_tensor, onnx_indices_bytes));

    std::vector<uint32_t> data_shape;
    ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(data_input.node_arg, data_shape),
                      "Failed to get data shape for GatherND.");

    std::vector<uint32_t> indices_shape;
    ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(indices_input.node_arg, indices_shape),
                      "Failed to get indices shape for GatherND.");

    if (indices_shape.empty()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Indices shape is empty for GatherND.");
    }

    // Get batch_dims for proper index processing
    NodeAttrHelper node_helper(node_unit);
    int64_t batch_dims = node_helper.Get("batch_dims", static_cast<int64_t>(0));

    if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
      // Narrow int64 -> int32 and wrap negative indices.
      ORT_RETURN_IF_NOT((
                            FixStaticIndicesForGatherND<int64_t, int32_t>(
                                onnx_indices_bytes,
                                std::vector<int64_t>(indices_shape.begin(), indices_shape.end()),
                                std::vector<int64_t>(data_shape.begin(), data_shape.end()),
                                batch_dims,
                                qnn_indices_bytes)),
                        "QNN does not support negative or out-of-bounds indices for GatherND.");
      indices_info.qnn_data_type = QNN_DATATYPE_INT_32;
    } else if (indices_info.qnn_data_type == QNN_DATATYPE_INT_32) {
      // Fix: int32 initializers may also contain negative indices. Previously these
      // bytes were forwarded unmodified, letting unsupported negative indices reach QNN.
      ORT_RETURN_IF_NOT((
                            FixStaticIndicesForGatherND<int32_t, int32_t>(
                                onnx_indices_bytes,
                                std::vector<int64_t>(indices_shape.begin(), indices_shape.end()),
                                std::vector<int64_t>(data_shape.begin(), data_shape.end()),
                                batch_dims,
                                qnn_indices_bytes)),
                        "QNN does not support negative or out-of-bounds indices for GatherND.");
    } else {
      // Other index types are passed through unchanged.
      qnn_indices_bytes = std::move(onnx_indices_bytes);
    }
  }

  Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(indices_tensor_name);
  // Copy the shape before indices_info.shape is moved into the tensor wrapper below.
  std::vector<uint32_t> cast_output_shape(indices_info.shape);

  if (!qnn_model_wrapper.IsQnnTensorWrapperExist(indices_tensor_name)) {
    QnnTensorWrapper input_tensorwrapper(indices_tensor_name, tensor_type, indices_info.qnn_data_type,
                                         QnnQuantParamsWrapper(), std::move(indices_info.shape),
                                         std::move(qnn_indices_bytes));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
  }

  // Dynamic int64 indices cannot be rewritten at build time; insert a Cast to int32.
  // (Initializer indices were already converted to INT_32 above, so only dynamic
  // tensors can still be INT_64 here.)
  std::string indices_casted_name{indices_tensor_name};
  if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
    assert(!indices_info.is_initializer);
    indices_casted_name += "_int32";
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_casted_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << indices_casted_name;
    } else {
      QnnTensorWrapper indices_cast_tensor(indices_casted_name,
                                           QNN_TENSOR_TYPE_NATIVE,
                                           QNN_DATATYPE_INT_32,
                                           QnnQuantParamsWrapper(),
                                           std::move(cast_output_shape));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(indices_cast_tensor)),
                        "Failed to add gather indices cast tensor.");
      ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_casted_name,
                                                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                        QNN_OP_CAST,
                                                        {indices_tensor_name},
                                                        {indices_casted_name},
                                                        {},
                                                        do_op_validation),
                        "Failed to add GatherNd indices cast node.");
    }
  }

  input_names.push_back(indices_casted_name);

  return Status::OK();
}
// Adds the batch_dims parameter and the QNN GatherNd node, then wires any trailing
// Reshape (rank mismatch between the computed QNN shape and the ONNX output shape)
// and Cast-to-int64 (int64 graph outputs) nodes so the final tensor matches the
// ONNX node's declared output.
Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                      const NodeUnit& node_unit,
                                                      std::vector<std::string>&& input_names,
                                                      const logging::Logger& logger,
                                                      bool do_op_validation) const {
  ORT_UNUSED_PARAMETER(logger);
  const auto& output = node_unit.Outputs()[0];
  const std::string& output_name = output.node_arg.Name();

  QnnQuantParamsWrapper quant_params;
  ORT_RETURN_IF_ERROR(quant_params.Init(qnn_model_wrapper, output));

  const auto* type_proto = output.node_arg.TypeAsProto();
  Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
  ORT_RETURN_IF_ERROR(utils::GetQnnDataType(quant_params.IsQuantized(), type_proto, qnn_data_type));

  if (quant_params.IsPerTensor()) {
    // Make sure the output quantization parameters are equal to the input.
    ORT_RETURN_IF_ERROR(SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names,
                                                                 0 /*input_index*/, 0 /*output_index*/, qnn_data_type,
                                                                 quant_params));
  }

  // batch_dims is an ONNX attribute; QNN takes it as a uint32 scalar parameter.
  NodeAttrHelper node_helper(node_unit);
  int64_t batch_dims = node_helper.Get("batch_dims", static_cast<int64_t>(0));

  Qnn_Scalar_t batch_dims_scalar = QNN_SCALAR_INIT;
  batch_dims_scalar.dataType = QNN_DATATYPE_UINT_32;
  batch_dims_scalar.uint32Value = static_cast<uint32_t>(batch_dims);

  QnnParamWrapper batch_dims_param(node_unit.Index(), node_unit.Name(),
                                   QNN_OP_GATHER_ND_PARAM_BATCH_DIMS, batch_dims_scalar);
  std::vector<std::string> param_tensor_names = {batch_dims_param.GetParamTensorName()};
  qnn_model_wrapper.AddParamWrapper(std::move(batch_dims_param));

  // Get tensor wrappers for shape calculation
  const auto& data_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]);
  const auto& indices_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[1]);

  // Calculate the QNN output shape for GatherND
  std::vector<uint32_t> qnn_output_shape;
  const auto& data_dims = data_tensor_wrapper.GetTensorDims();
  const auto& indices_dims = indices_tensor_wrapper.GetTensorDims();

  // GatherND output shape calculation:
  size_t batch_dims_size = static_cast<size_t>(batch_dims);
  size_t indices_last_dim = indices_dims.back();

  // Add batch dimensions from data
  // NOTE(review): per the ONNX GatherND spec, indices_dims[:-1] already begins with
  // the batch dimensions, so for batch_dims > 0 this loop plus the next one appear
  // to emit the batch dims twice — verify against the QNN GatherNd output definition.
  for (size_t i = 0; i < batch_dims_size && i < data_dims.size(); ++i) {
    qnn_output_shape.push_back(data_dims[i]);
  }

  // Add indices dimensions except the last one
  for (size_t i = 0; i < indices_dims.size() - 1; ++i) {
    qnn_output_shape.push_back(indices_dims[i]);
  }

  // Add remaining data dimensions after batch_dims + indices_last_dim
  size_t start_dim = batch_dims_size + indices_last_dim;
  for (size_t i = start_dim; i < data_dims.size(); ++i) {
    qnn_output_shape.push_back(data_dims[i]);
  }

  std::vector<uint32_t> target_output_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, target_output_shape),
                    "Cannot get target output shape");

  // A Reshape is appended when the computed QNN rank differs from the ONNX rank.
  bool reshape_required = (qnn_output_shape.size() != target_output_shape.size());
  bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);

  // Check if we need to add a cast node for int64
  // NOTE(review): this matches inputs whose names carry the "_cast_int32" suffix
  // (added by the base ProcessInput for int64 data). The indices cast created in
  // ProcessInputs above uses the "_int32" suffix instead, so it is NOT matched here —
  // confirm whether that asymmetry is intended.
  bool needs_int64_cast = false;
  if (is_graph_output) {
    for (const auto& input_name : input_names) {
      if (input_name.find("_cast_int32") != std::string::npos) {
        needs_int64_cast = true;
        break;
      }
    }
  }
  // Bookkeeping for Cast nodes that must be created after the GatherNd node.
  struct CastNodeInfo {
    std::string node_name;
    std::string input_name;
    std::string output_name;
  };
  std::vector<CastNodeInfo> cast_node_info_vec;

  // Get the output info for the gather output tensor
  TensorInfo output_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, output_info));

  // If a cast to int64 is needed, add the cast node
  if (needs_int64_cast) {
    std::string cast_node_name = output_name + "_cast_int64";
    std::string cast_input_name = output_name + "_cast_int64_aux";
    std::string cast_output_name = output_name;

    // Create the cast input tensor wrapper - use qnn_output_shape for the intermediate tensor
    QnnTensorWrapper cast_input_tensorwrapper(cast_input_name,
                                              QNN_TENSOR_TYPE_NATIVE,
                                              output_info.qnn_data_type,
                                              output_info.quant_param.Copy(),
                                              std::vector<uint32_t>(qnn_output_shape));

    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor.");
    cast_node_info_vec.push_back({cast_node_name, cast_input_name, cast_output_name});
    Qnn_TensorType_t cast_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
    // Final graph-output tensor the Cast writes into.
    QnnTensorWrapper cast_output(output_name, cast_tensor_type, qnn_data_type, quant_params.Copy(),
                                 std::vector<uint32_t>(target_output_shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
  }

  // Pick the GatherNd node's immediate output name: an auxiliary tensor when a
  // Reshape and/or Cast follows, otherwise the final ONNX output name.
  std::string gather_output_name = output_name;
  if (reshape_required) {
    gather_output_name += "_ort_qnn_ep_reshape";
  } else if (needs_int64_cast) {
    gather_output_name += "_cast_int64_aux";
  }

  Qnn_TensorType_t tensor_type = (!reshape_required && is_graph_output)
                                     ? QNN_TENSOR_TYPE_APP_READ
                                     : QNN_TENSOR_TYPE_NATIVE;

  QnnTensorWrapper gather_output_tensor(gather_output_name, tensor_type, qnn_data_type,
                                        quant_params.Copy(), std::move(qnn_output_shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_tensor)),
                    "Failed to add GatherND output tensor.");

  // Create the GatherNd node itself. input_names and param_tensor_names are consumed here.
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                    QNN_OP_GATHER_ND,
                                                    std::move(input_names),
                                                    {gather_output_name},
                                                    std::move(param_tensor_names),
                                                    do_op_validation),
                    "Failed to create GatherND node.");

  if (reshape_required) {
    Qnn_TensorType_t reshape_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
    QnnTensorWrapper reshape_output(output_name, reshape_tensor_type, qnn_data_type,
                                    std::move(quant_params), std::move(target_output_shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_output)), "Failed to add reshape output.");

    std::string node_output_name = output_name;
    if (needs_int64_cast) {
      // If needs_int64 is true, the output name should be the input name of the cast node
      node_output_name = output_name + "_cast_int64_aux";
    }

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(output_name,
                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                      QNN_OP_RESHAPE,
                                                      {gather_output_name},
                                                      {node_output_name},
                                                      {},
                                                      do_op_validation),
                      "Failed to add Reshape node.");
  }

  if (needs_int64_cast) {
    for (const auto& cast_node_info : cast_node_info_vec) {
      // Insert cast node.
      ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_node_info.node_name,
                                                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                        QNN_OP_CAST,
                                                        {cast_node_info.input_name},
                                                        {cast_node_info.output_name},
                                                        {}),
                        "Failed to add Cast node");
    }
  }

  return Status::OK();
}
// Registers the GatherND op builder under the given ONNX op type name.
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto builder = std::make_unique<GatherNDOpBuilder>();
  op_registrations.AddOpBuilder(op_type, std::move(builder));
}
} // namespace qnn
344+
} // namespace onnxruntime

onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ TEST(GatherNDOpTest, int64_t) {
8585
RunTest<int64_t>({2, 2, 2}, {0LL, 1LL, 2LL, 3LL, 4LL, 5LL, 6LL, 7LL}, {2, 1, 1}, {1, 0}, {2, 1, 2, 2},
8686
{4LL, 5LL, 6LL, 7LL, 0LL, 1LL, 2LL, 3LL});
8787

88+
if (DefaultQnnExecutionProvider().get() != nullptr) {
89+
GTEST_SKIP() << "Skipping because QNN CPU does not support negative indices being inputs.";
90+
}
91+
8892
// with negative indices
8993
RunTest<int64_t>({2, 2, 2}, {0LL, 1LL, 2LL, 3LL, 4LL, 5LL, 6LL, 7LL}, {2, 1, 1}, {-1, 0}, {2, 1, 2, 2},
9094
{4LL, 5LL, 6LL, 7LL, 0LL, 1LL, 2LL, 3LL});
@@ -97,6 +101,10 @@ TEST(GatherNDOpTest, float) {
97101

98102
RunTest<float>({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
99103

104+
if (DefaultQnnExecutionProvider().get() != nullptr) {
105+
GTEST_SKIP() << "Skipping because QNN CPU does not support negative indices being inputs.";
106+
}
107+
100108
// with negative indices
101109
RunTest<float>({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {-1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
102110
}

0 commit comments

Comments
 (0)