// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3+
4+ #include < memory>
5+ #include < string>
6+ #include < utility>
7+ #include < vector>
8+
39#include " core/providers/qnn/builder/opbuilder/base_op_builder.h"
410#include " core/providers/qnn/builder/op_builder_factory.h"
511#include " core/providers/qnn/builder/qnn_utils.h"
12+
613namespace onnxruntime {
714namespace qnn {
15+
816const int TOPK_MIN_INPUT = 2 ;
917const int TOPK_MAX_INPUT = 2 ;
18+
1019class TopKOpBuilder : public BaseOpBuilder {
1120 public:
1221 TopKOpBuilder () : BaseOpBuilder(" TopKOpBuilder" ) {}
@@ -41,8 +50,11 @@ class TopKOpBuilder : public BaseOpBuilder {
4150
4251Status TopKOpBuilder::ExplictOpCheck (QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
4352 size_t input_count = node_unit.Inputs ().size ();
53+ size_t output_count = node_unit.Outputs ().size ();
4454 ORT_RETURN_IF_NOT (input_count >= TOPK_MIN_INPUT && input_count <= TOPK_MAX_INPUT,
4555 " For ONNX TopK operation the expected number of inputs is 2." );
56+ ORT_RETURN_IF_NOT (output_count == 2 , " QNN TopK expects exactly 2 outputs." );
57+
4658 // Skip the first input. The second input needs to be an initializer.
4759 const auto & input_1 = node_unit.Inputs ()[1 ].node_arg .Name ();
4860 if (!qnn_model_wrapper.IsConstantInput (input_1)) {
@@ -57,14 +69,6 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
5769 if (0 == largest) {
5870 return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " QNN TopK output is always largest values" );
5971 }
60- auto & input_0 = node_unit.Inputs ()[0 ];
61- std::vector<uint32_t > input_shape;
62- ORT_RETURN_IF_NOT (qnn_model_wrapper.GetOnnxShape (input_0.node_arg , input_shape), " Cannot get shape" );
63- auto rank = input_shape.size ();
64- auto axis = node_helper.Get (" axis" , -1 );
65-
66- ORT_RETURN_IF_NOT (axis == -1 || axis == static_cast <int32_t >(rank - 1 ),
67- " QNN TopK's axis is always the last dimension" );
6872
6973 return Status::OK ();
7074}
@@ -81,6 +85,40 @@ Status TopKOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
8185 const auto & inputs = node_unit.Inputs ();
8286 ORT_RETURN_IF_ERROR (ProcessInput (qnn_model_wrapper, inputs[0 ], logger, input_names));
8387
88+ // HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
89+ TensorInfo input_info = {};
90+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (node_unit.Inputs ()[0 ], input_info));
91+
92+ size_t input_rank = input_info.shape .size ();
93+ int32_t axis = NodeAttrHelper (node_unit).Get (" axis" , -1 );
94+ if (axis == -1 || axis == static_cast <int32_t >(input_rank - 1 )) {
95+ return Status::OK ();
96+ }
97+
98+ // Add Transpose to permute axis to the last.
99+ std::string transpose_output_name = input_names[0 ] + " _ort_qnn_ep_transpose" ;
100+ std::vector<uint32_t > transpose_perm;
101+ ORT_RETURN_IF_ERROR (utils::GetPermToLastAxis (static_cast <uint32_t >(axis),
102+ static_cast <uint32_t >(input_rank),
103+ transpose_perm));
104+
105+ std::vector<uint32_t > transpose_output_shape = input_info.shape ;
106+ transpose_output_shape[input_rank - 1 ] = input_info.shape [axis];
107+ transpose_output_shape[axis] = input_info.shape [input_rank - 1 ];
108+
109+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
110+ input_names[0 ],
111+ transpose_output_name,
112+ input_info.shape ,
113+ transpose_perm,
114+ transpose_output_shape,
115+ input_info.qnn_data_type ,
116+ input_info.quant_param ,
117+ do_op_validation,
118+ false ,
119+ false ));
120+ input_names[0 ] = transpose_output_name;
121+
84122 return Status::OK ();
85123}
86124
@@ -108,9 +146,125 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
108146 std::string k_param_name = k_param.GetParamTensorName ();
109147 qnn_model_wrapper.AddParamWrapper (std::move (k_param));
110148 std::vector<std::string> param_tensor_names{k_param_name};
111- ORT_RETURN_IF_ERROR (ProcessOutputs (qnn_model_wrapper, node_unit, std::move (input_names),
112- std::move (param_tensor_names), logger, do_op_validation,
113- GetQnnOpType (node_unit.OpType ())));
149+
150+ // HTP only supports TopK at the last axis, and thus check whether extra Transpose is required.
151+ TensorInfo input_info = {};
152+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (node_unit.Inputs ()[0 ], input_info));
153+
154+ size_t input_rank = input_info.shape .size ();
155+ int32_t axis = NodeAttrHelper (node_unit).Get (" axis" , -1 );
156+ if (axis == -1 || axis == static_cast <int32_t >(input_rank - 1 )) {
157+ ORT_RETURN_IF_ERROR (ProcessOutputs (qnn_model_wrapper,
158+ node_unit,
159+ std::move (input_names),
160+ std::move (param_tensor_names),
161+ logger,
162+ do_op_validation,
163+ GetQnnOpType (node_unit.OpType ())));
164+ return Status::OK ();
165+ }
166+
167+ const auto & outputs = node_unit.Outputs ();
168+ std::vector<std::string> transpose_input_names;
169+ std::vector<std::vector<std::uint32_t >> transpose_input_shapes;
170+
171+ // Add TopK outputs.
172+ for (size_t output_idx = 0 ; output_idx < 2 ; ++output_idx) {
173+ const auto & output = outputs[output_idx];
174+
175+ // Since user may not be aware of the additional Transpose, the original output name of TopK node must be used by
176+ // the additional Transpose node which has the same output as original TopK node.
177+ const std::string& output_name = output.node_arg .Name ();
178+ std::string transpose_input_name = output_name + " _ort_qnn_ep_transpose" ;
179+ transpose_input_names.push_back (std::move (transpose_input_name));
180+
181+ // Since the input of TopK node is permuted, its output shape must be manually calculated.
182+ TensorInfo output_info = {};
183+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (output, output_info));
184+ size_t output_rank = output_info.shape .size ();
185+
186+ std::vector<uint32_t > transpose_input_shape = output_info.shape ;
187+ transpose_input_shape[output_rank - 1 ] = output_info.shape [axis];
188+ transpose_input_shape[axis] = output_info.shape [output_rank - 1 ];
189+ transpose_input_shapes.push_back (std::move (transpose_input_shape));
190+
191+ QnnTensorWrapper output_tensorwrapper (transpose_input_names[output_idx],
192+ QNN_TENSOR_TYPE_NATIVE,
193+ output_info.qnn_data_type ,
194+ output_info.quant_param .Copy (),
195+ std::vector<uint32_t >(transpose_input_shapes[output_idx]));
196+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add tensor." );
197+ }
198+
199+ // Add TopK node.
200+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (utils::GetNodeName (node_unit),
201+ QNN_OP_PACKAGE_NAME_QTI_AISW,
202+ GetQnnOpType (node_unit.OpType ()),
203+ std::move (input_names),
204+ std::vector<std::string>(transpose_input_names),
205+ std::move (param_tensor_names)),
206+ " Failed to add node." );
207+
208+ // Add Transpose nodes for each output to permute back.
209+ for (size_t output_idx = 0 ; output_idx < 2 ; ++output_idx) {
210+ const auto & output = outputs[output_idx];
211+ const std::string& output_name = output.node_arg .Name ();
212+
213+ TensorInfo output_info = {};
214+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (output, output_info));
215+ size_t output_rank = output_info.shape .size ();
216+
217+ std::vector<uint32_t > transpose_perm;
218+ ORT_RETURN_IF_ERROR (utils::GetPermToLastAxis (static_cast <uint32_t >(axis),
219+ static_cast <uint32_t >(output_rank),
220+ transpose_perm));
221+
222+ std::string transpose_output_name = output_name;
223+ bool is_graph_output = qnn_model_wrapper.IsGraphOutput (output_name);
224+
225+ // TopK's second output is indices which could be INT64 dtype, and QnnTensorWrapper directly changes the dtype to
226+ // INT32 during the wrapper construction. Nevertheless, if this output happens to be graph output, an additional
227+ // Cast must be added to cast dtype from INT32 back to INT64.
228+ bool is_cast_required = output_idx == 1 && output_info.qnn_data_type == QNN_DATATYPE_INT_64 && is_graph_output;
229+ std::string cast_input_name = " " ;
230+ if (is_cast_required) {
231+ cast_input_name = transpose_output_name + " _ort_qnn_ep_cast" ;
232+ // For the same reason described above, the original output name is now used by this Cast.
233+ transpose_output_name = cast_input_name;
234+ // Since additional Cast is added, below Transpose is no longer graph output.
235+ is_graph_output = false ;
236+ }
237+
238+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
239+ transpose_input_names[output_idx],
240+ transpose_output_name,
241+ transpose_input_shapes[output_idx],
242+ transpose_perm,
243+ output_info.shape ,
244+ output_info.qnn_data_type ,
245+ output_info.quant_param ,
246+ do_op_validation,
247+ false ,
248+ is_graph_output));
249+
250+ if (is_cast_required) {
251+ QnnTensorWrapper cast_output_tensorwrapper (output_name,
252+ QNN_TENSOR_TYPE_APP_READ,
253+ output_info.qnn_data_type ,
254+ output_info.quant_param .Copy (),
255+ std::vector<uint32_t >(output_info.shape ));
256+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (cast_output_tensorwrapper)),
257+ " Failed to add tensor." );
258+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (cast_input_name,
259+ QNN_OP_PACKAGE_NAME_QTI_AISW,
260+ " Cast" ,
261+ {cast_input_name},
262+ {output_name},
263+ {}),
264+ " Failed to add node" );
265+ }
266+ }
267+
114268 return Status::OK ();
115269}
116270
0 commit comments