Skip to content

Commit c22f70d

Browse files
authored
[QNN-EP] Einsum equation ReduceSum Multiply on broadcast X (microsoft#25581)
[QNN-EP] Einsum equation ReduceSum Multiply on broadcast X
1 parent a89b038 commit c22f70d

File tree

2 files changed

+225
-5
lines changed

2 files changed

+225
-5
lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,50 @@ bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) {
192192
return true;
193193
}
194194

195+
bool IsEquationReduceSumMulBroadcastX(const Equation& equation) {
196+
// E.g., bhwc,wkc->bhwk
197+
const auto& [term_1, term_2, result] = equation;
198+
if (term_1.size() != 4) {
199+
return false;
200+
}
201+
if (term_2.size() != 3) {
202+
return false;
203+
}
204+
if (result.size() != 4) {
205+
return false;
206+
}
207+
208+
// Check contraction over last axis (c)
209+
char c1 = term_1[3];
210+
char c2 = term_2[2];
211+
if (c1 != c2) {
212+
return false;
213+
}
214+
215+
// Check w axis alignment
216+
if (term_1[2] != term_2[0]) {
217+
return false;
218+
}
219+
if (term_1[2] != result[2]) {
220+
return false;
221+
}
222+
223+
// Check k axis alignment
224+
if (term_2[1] != result[3]) {
225+
return false;
226+
}
227+
228+
// Check batch dimensions
229+
if (term_1[0] != result[0]) {
230+
return false;
231+
}
232+
if (term_1[1] != result[1]) {
233+
return false;
234+
}
235+
236+
return true;
237+
}
238+
195239
/**
196240
* @brief Sets the parameter tensor names for a MatMul op.
197241
*
@@ -305,6 +349,113 @@ Status CreateMatMulTransposeAll(
305349
return Status::OK();
306350
}
307351

352+
/**
 * @brief Creates a ReduceSum, Multiply on broadcasted input X and original input Y.
 *
 * Lowers the einsum equation bhwc,wkc->bhwk into three QNN ops:
 *   1. Reshape in0 from (b, h, w, c) to (b, h, w, 1, c).
 *   2. ElementWiseMultiply with in1 (w, k, c), broadcasting to (b, h, w, k, c).
 *   3. ReduceSum over the trailing 'c' axis (keep_dims=false) -> (b, h, w, k).
 *
 * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model.
 * @param node_unit The NodeUnit representing the ONNX node to be converted.
 * @param input_names QNN tensor names of the node's two inputs (consumed by this call).
 * @param do_op_validation A boolean flag indicating whether to perform operation validation.
 * @return Status indicating success or failure of the operation.
 */
Status CreateReduceSumMulBroadcastX(
    onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
    const onnxruntime::NodeUnit& node_unit,
    std::vector<std::string>&& input_names,
    bool do_op_validation) {
  // Reshape in0 to shape (b, h, w, 1, c) to expand dimension before the contraction axis 'c'.
  // Allowing broadcast with in1 for multiplication, aligning the contraction axis for reduce.
  onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], tensor_info_in0));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], tensor_info_in1));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Outputs()[0], tensor_info_out));
  const std::vector<uint32_t>& shape_in0 = tensor_info_in0.shape;
  const std::vector<uint32_t>& shape_in1 = tensor_info_in1.shape;
  // Ranks were already matched by IsEquationReduceSumMulBroadcastX; re-check defensively.
  ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4");
  ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3");
  const std::vector<uint32_t> new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]};
  const std::string reshape_out_name = input_names[0] + "_reshaped";
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode(
      /*input_name=*/input_names[0],
      /*output_name=*/reshape_out_name,
      /*input_shape=*/shape_in0,
      /*output_shape=*/new_shape_in0,
      /*tensor_data_type=*/tensor_info_in0.qnn_data_type,
      /*quantize_param=*/tensor_info_in0.quant_param.Copy(),
      /*do_op_validation=*/do_op_validation,
      /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0])));

  // Multiply: reshaped in0 * in1
  // The output shape of the multiplication is determined by broadcasting the reshaped in0 of
  // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c).
  const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul";
  std::vector<uint32_t> shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]};
  // NOTE(review): the Mul output reuses in0's quant params; quantized inputs are rejected
  // by IsOpSupported, so this path presumably only ever sees non-quantized encodings — confirm
  // if quantized support is re-enabled.
  onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name,
                                                        QNN_TENSOR_TYPE_NATIVE,
                                                        tensor_info_in0.qnn_data_type,
                                                        tensor_info_in0.quant_param.Copy(),
                                                        std::move(shape_out_mul));
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)),
                    "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");
  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
                        /*qnn_node_name=*/mul_out_name,
                        /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
                        /*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY,
                        /*input_names=*/{reshape_out_name, input_names[1]},
                        /*output_names=*/{mul_out_name},
                        /*param_tensor_names=*/{},
                        /*do_op_validation=*/do_op_validation),
                    "CreateReduceSumMulBroadcastX: failed to create Mul node");

  std::vector<std::string> param_tensor_names{};

  // ReduceSum on last axes={4}, keep_dims=False
  // Axis '4' corresponds to the last dimension ('c') of the reshaped tensor (b, h, w, k, c),
  // which is the contraction axis for reduce sum op in the einsum equation (bhwc,wkc->bhwk).
  std::vector<uint32_t> axes_shape{SafeInt<uint32_t>(1)};
  std::vector<uint32_t> axes_value{SafeInt<uint32_t>(4)};
  onnxruntime::qnn::QnnParamWrapper param_axes(node_unit.Index(),
                                               node_unit.Name(),
                                               QNN_OP_REDUCE_SUM_PARAM_AXES,
                                               std::move(axes_shape),
                                               std::move(axes_value));
  // Param tensor name must be captured before the wrapper is moved into the model.
  param_tensor_names.push_back(param_axes.GetParamTensorName());
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_axes)),
                    "CreateReduceSumMulBroadcastX: failed to add param axes");

  // keep_dims=false drops the reduced 'c' axis so the output matches the
  // rank-4 result term of the equation.
  Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT;
  keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8;
  keep_dims_scalar.bool8Value = SafeInt<uint8_t>(0);
  onnxruntime::qnn::QnnParamWrapper param_keep_dims(node_unit.Index(),
                                                    node_unit.Name(),
                                                    QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS,
                                                    keep_dims_scalar);
  param_tensor_names.push_back(param_keep_dims.GetParamTensorName());
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_keep_dims)),
                    "CreateReduceSumMulBroadcastX: failed to add param keep_dims");

  // The ReduceSum output is the node's output tensor; mark it APP_READ when it
  // is also a graph output so the client can read it back.
  const std::string out_name = node_unit.Outputs()[0].node_arg.Name();
  Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput(out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
  onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out(out_name,
                                                        out_tensor_type,
                                                        tensor_info_out.qnn_data_type,
                                                        tensor_info_out.quant_param.Copy(),
                                                        std::move(tensor_info_out.shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_out)),
                    "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");

  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
                        /*qnn_node_name=*/out_name,
                        /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
                        /*qnn_node_type=*/QNN_OP_REDUCE_SUM,
                        /*input_names=*/{mul_out_name},
                        /*output_names=*/{out_name},
                        /*param_tensor_names=*/std::move(param_tensor_names),
                        /*do_op_validation=*/do_op_validation),
                    "CreateReduceSumMulBroadcastX: failed to create ReduceSum node");

  return Status::OK();
}
458+
308459
} // namespace
309460

310461
namespace onnxruntime {
@@ -356,9 +507,20 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
356507
if (!IsEquationMatMul(parsed_equation.value()) &&
357508
!IsEquationMatMulTransposeY(parsed_equation.value()) &&
358509
!IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) &&
359-
!IsEquationMatMulTransposeAll(parsed_equation.value())) {
510+
!IsEquationMatMulTransposeAll(parsed_equation.value()) &&
511+
!IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
360512
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
361513
}
514+
if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
515+
if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
516+
// QAIRT 3.36.1: Failed to validate on GPU.
517+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " on backend GPU");
518+
}
519+
if (node_unit.Inputs()[0].quant_param.has_value()) {
520+
// QAIRT 3.36.1: Failed to finalize QNN graph 1002.
521+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " for quantized inputs");
522+
}
523+
}
362524
return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
363525
}
364526

@@ -408,6 +570,11 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
408570
/*node_unit=*/node_unit,
409571
/*input_names=*/std::move(input_names),
410572
/*do_op_validation=*/do_op_validation));
573+
} else if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
574+
ORT_RETURN_IF_ERROR(CreateReduceSumMulBroadcastX(/*qnn_model_wrapper=*/&qnn_model_wrapper,
575+
/*node_unit=*/node_unit,
576+
/*input_names=*/std::move(input_names),
577+
/*do_op_validation=*/do_op_validation));
411578
} else {
412579
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
413580
}

onnxruntime/test/providers/qnn/einsum_op_test.cc

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
189189
/*tolerance=*/1e-4f);
190190
}
191191

192+
// MatMul with all axes permuted: bkhq,bchk->bchq, on the QNN CPU backend.
TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
  const std::vector<int64_t> lhs_shape{1, 7, 1, 7};
  const std::vector<int64_t> rhs_shape{1, 9, 1, 7};
  RunQnnEinsum<float>(
      /*backend=*/kQnnBackendTypeCpu,
      /*in0=*/TestInputDef<float>(lhs_shape, /*is_initializer=*/false,
                                  GetSequentialFloatData(lhs_shape, /*start=*/-0.1f, /*step=*/0.05f)),
      /*in1=*/TestInputDef<float>(rhs_shape, /*is_initializer=*/false,
                                  GetSequentialFloatData(rhs_shape, /*start=*/-0.1f, /*step=*/0.05f)),
      /*equation=*/"bkhq,bchk->bchq",
      /*tolerance=*/1e-4f);
}
204+
192205
TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
193206
const std::vector<int64_t> shape0{2, 3, 3, 4};
194207
const std::vector<int64_t> shape1{3, 3, 4};
@@ -202,16 +215,16 @@ TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
202215
/*tolerance=*/1e-4f);
203216
}
204217

205-
TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
206-
const std::vector<int64_t> shape0{1, 7, 1, 7};
207-
const std::vector<int64_t> shape1{1, 9, 1, 7};
218+
// Broadcast-multiply + reduce-sum pattern (bhwc,wkc->bhwk) on the QNN CPU backend.
TEST_F(QnnCPUBackendTests, EinsumReduceSumMulBroadcastX) {
  const std::vector<int64_t> lhs_shape{2, 3, 4, 5};
  const std::vector<int64_t> rhs_shape{4, 6, 5};
  RunQnnEinsum<float>(
      /*backend=*/kQnnBackendTypeCpu,
      /*in0=*/TestInputDef<float>(lhs_shape, /*is_initializer=*/false,
                                  GetSequentialFloatData(lhs_shape, /*start=*/-0.1f, /*step=*/0.05f)),
      /*in1=*/TestInputDef<float>(rhs_shape, /*is_initializer=*/false,
                                  GetSequentialFloatData(rhs_shape, /*start=*/-0.1f, /*step=*/0.05f)),
      /*equation=*/"bhwc,wkc->bhwk",
      /*tolerance=*/1e-4f);
}
217230

@@ -299,6 +312,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) {
299312
/*tolerance=*/1e-2f);
300313
}
301314

315+
// Broadcast-multiply + reduce-sum pattern (bhwc,wkc->bhwk) on the QNN HTP backend (fp16).
TEST_F(QnnHTPBackendTests, EinsumF16ReduceSumMulBroadcastX) {
  const std::vector<int64_t> lhs_shape{1, 3, 2, 4};
  const std::vector<int64_t> rhs_shape{2, 3, 4};
  RunQnnEinsum<float>(
      /*backend=*/kQnnBackendTypeHtp,
      /*in0=*/TestInputDef<float>(lhs_shape, /*is_initializer=*/false,
                                  GetSequentialFloatData(lhs_shape, /*start=*/-0.1f, /*step=*/0.05f)),
      /*in1=*/TestInputDef<float>(rhs_shape, /*is_initializer=*/false,
                                  GetSequentialFloatData(rhs_shape, /*start=*/-0.1f, /*step=*/0.05f)),
      /*equation=*/"bhwc,wkc->bhwk",
      /*tolerance=*/1e-2f);
}
327+
302328
//
303329
// QNN HTP QDQ
304330
//
@@ -375,6 +401,19 @@ TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) {
375401
/*tolerance=*/QDQTolerance());
376402
}
377403

404+
// TODO: Re-enable. QAIRT 3.36.1: failed to finalize QNN graph 1002.
405+
TEST_F(QnnHTPBackendTests, DISABLED_EinsumQdqReduceSumMulBroadcastX) {
406+
const std::vector<int64_t> shape0{1, 3, 2, 4};
407+
const std::vector<int64_t> shape1{2, 3, 4};
408+
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
409+
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
410+
RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
411+
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
412+
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
413+
/*equation=*/"bhwc,wkc->bhwk",
414+
/*tolerance=*/QDQTolerance());
415+
}
416+
378417
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
379418

380419
#if defined(_M_ARM64)
@@ -474,6 +513,20 @@ TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) {
474513
/*tolerance=*/1e-4f);
475514
}
476515

516+
// TODO: Re-enable. Failed on QAIRT 3.36.1.
517+
TEST_F(QnnGPUBackendTests, DISABLED_EinsumReduceSumMulBroadcastX) {
518+
const std::vector<int64_t> shape0{1, 3, 2, 4};
519+
const std::vector<int64_t> shape1{2, 3, 4};
520+
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
521+
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
522+
RunQnnEinsum<float>(
523+
/*backend=*/kQnnBackendTypeGpu,
524+
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
525+
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
526+
/*equation=*/"bhwc,wkc->bhwk",
527+
/*tolerance=*/1e-4f);
528+
}
529+
477530
#endif // defined(_M_ARM64) GPU tests
478531

479532
} // namespace test

0 commit comments

Comments
 (0)