@@ -192,6 +192,50 @@ bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) {
192192 return true ;
193193}
194194
195+ bool IsEquationReduceSumMulBroadcastX (const Equation& equation) {
196+ // E.g., bhwc,wkc->bhwk
197+ const auto & [term_1, term_2, result] = equation;
198+ if (term_1.size () != 4 ) {
199+ return false ;
200+ }
201+ if (term_2.size () != 3 ) {
202+ return false ;
203+ }
204+ if (result.size () != 4 ) {
205+ return false ;
206+ }
207+
208+ // Check contraction over last axis (c)
209+ char c1 = term_1[3 ];
210+ char c2 = term_2[2 ];
211+ if (c1 != c2) {
212+ return false ;
213+ }
214+
215+ // Check w axis alignment
216+ if (term_1[2 ] != term_2[0 ]) {
217+ return false ;
218+ }
219+ if (term_1[2 ] != result[2 ]) {
220+ return false ;
221+ }
222+
223+ // Check k axis alignment
224+ if (term_2[1 ] != result[3 ]) {
225+ return false ;
226+ }
227+
228+ // Check batch dimensions
229+ if (term_1[0 ] != result[0 ]) {
230+ return false ;
231+ }
232+ if (term_1[1 ] != result[1 ]) {
233+ return false ;
234+ }
235+
236+ return true ;
237+ }
238+
195239/* *
196240 * @brief Sets the parameter tensor names for a MatMul op.
197241 *
@@ -305,6 +349,113 @@ Status CreateMatMulTransposeAll(
305349 return Status::OK ();
306350}
307351
352+ /* *
353+ * @brief Creates a ReduceSum, Multiply on broadcasted input X and original input Y.
354+ *
355+ * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model.
356+ * @param node_unit The NodeUnit representing the ONNX node to be converted.
357+ * @param do_op_validation A boolean flag indicating whether to perform operation validation.
358+ * @return Status indicating success or failure of the operation.
359+ */
360+ Status CreateReduceSumMulBroadcastX (
361+ onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
362+ const onnxruntime::NodeUnit& node_unit,
363+ std::vector<std::string>&& input_names,
364+ bool do_op_validation) {
365+ // Reshape in0 to shape (b, h, w, 1, c) to expand dimension before the contraction axis 'c'.
366+ // Allowing broadcast with in1 for multiplication, aligning the contraction axis for reduce.
367+ onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{};
368+ ORT_RETURN_IF_ERROR (qnn_model_wrapper->GetTensorInfo (node_unit.Inputs ()[0 ], tensor_info_in0));
369+ ORT_RETURN_IF_ERROR (qnn_model_wrapper->GetTensorInfo (node_unit.Inputs ()[1 ], tensor_info_in1));
370+ ORT_RETURN_IF_ERROR (qnn_model_wrapper->GetTensorInfo (node_unit.Outputs ()[0 ], tensor_info_out));
371+ const std::vector<uint32_t >& shape_in0 = tensor_info_in0.shape ;
372+ const std::vector<uint32_t >& shape_in1 = tensor_info_in1.shape ;
373+ ORT_RETURN_IF_NOT (shape_in0.size () == 4 , " CreateReduceSumMulBroadcastX expects input 0 to be rank 4" );
374+ ORT_RETURN_IF_NOT (shape_in1.size () == 3 , " CreateReduceSumMulBroadcastX expects input 1 to be rank 3" );
375+ const std::vector<uint32_t > new_shape_in0{shape_in0[0 ], shape_in0[1 ], shape_in0[2 ], 1 , shape_in0[3 ]};
376+ const std::string reshape_out_name = input_names[0 ] + " _reshaped" ;
377+ ORT_RETURN_IF_ERROR (qnn_model_wrapper->AddReshapeNode (
378+ /* input_name=*/ input_names[0 ],
379+ /* output_name=*/ reshape_out_name,
380+ /* input_shape=*/ shape_in0,
381+ /* output_shape=*/ new_shape_in0,
382+ /* tensor_data_type=*/ tensor_info_in0.qnn_data_type ,
383+ /* quantize_param=*/ tensor_info_in0.quant_param .Copy (),
384+ /* do_op_validation=*/ do_op_validation,
385+ /* is_for_input=*/ qnn_model_wrapper->IsGraphInput (input_names[0 ])));
386+
387+ // Multiply: reshaped in0 * in1
388+ // The output shape of the multiplication is determined by broadcasting the reshaped in0 of
389+ // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c).
390+ const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName (node_unit) + " _mul" ;
391+ std::vector<uint32_t > shape_out_mul{new_shape_in0[0 ], new_shape_in0[1 ], new_shape_in0[2 ], shape_in1[1 ], new_shape_in0[4 ]};
392+ onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul (mul_out_name,
393+ QNN_TENSOR_TYPE_NATIVE,
394+ tensor_info_in0.qnn_data_type ,
395+ tensor_info_in0.quant_param .Copy (),
396+ std::move (shape_out_mul));
397+ ORT_RETURN_IF_NOT (qnn_model_wrapper->AddTensorWrapper (std::move (tensor_wrapper_mul)),
398+ " CreateReduceSumMulBroadcastX: failed to AddTensorWrapper" );
399+ ORT_RETURN_IF_NOT (qnn_model_wrapper->CreateQnnNode (
400+ /* qnn_node_name=*/ mul_out_name,
401+ /* package_name=*/ QNN_OP_PACKAGE_NAME_QTI_AISW,
402+ /* qnn_node_type=*/ QNN_OP_ELEMENT_WISE_MULTIPLY,
403+ /* input_names=*/ {reshape_out_name, input_names[1 ]},
404+ /* output_names=*/ {mul_out_name},
405+ /* param_tensor_names=*/ {},
406+ /* do_op_validation=*/ do_op_validation),
407+ " CreateReduceSumMulBroadcastX: failed to create Mul node" );
408+
409+ std::vector<std::string> param_tensor_names{};
410+
411+ // ReduceSum on last axes={4}, keep_dims=False
412+ // Axis '4' corresponds to the last dimension ('c') of the reshaped tensor (b, h, w, k, c),
413+ // which is the contraction axis for reduce sum op in the einsum equation (bhwc,wkc->bhwk).
414+ std::vector<uint32_t > axes_shape{SafeInt<uint32_t >(1 )};
415+ std::vector<uint32_t > axes_value{SafeInt<uint32_t >(4 )};
416+ onnxruntime::qnn::QnnParamWrapper param_axes (node_unit.Index (),
417+ node_unit.Name (),
418+ QNN_OP_REDUCE_SUM_PARAM_AXES,
419+ std::move (axes_shape),
420+ std::move (axes_value));
421+ param_tensor_names.push_back (param_axes.GetParamTensorName ());
422+ ORT_RETURN_IF_NOT (qnn_model_wrapper->AddParamWrapper (std::move (param_axes)),
423+ " CreateReduceSumMulBroadcastX: failed to add param axes" );
424+
425+ Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT;
426+ keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8;
427+ keep_dims_scalar.bool8Value = SafeInt<uint8_t >(0 );
428+ onnxruntime::qnn::QnnParamWrapper param_keep_dims (node_unit.Index (),
429+ node_unit.Name (),
430+ QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS,
431+ keep_dims_scalar);
432+ param_tensor_names.push_back (param_keep_dims.GetParamTensorName ());
433+ ORT_RETURN_IF_NOT (qnn_model_wrapper->AddParamWrapper (std::move (param_keep_dims)),
434+ " CreateReduceSumMulBroadcastX: failed to add param keep_dims" );
435+
436+ const std::string out_name = node_unit.Outputs ()[0 ].node_arg .Name ();
437+ Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput (out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
438+ onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out (out_name,
439+ out_tensor_type,
440+ tensor_info_out.qnn_data_type ,
441+ tensor_info_out.quant_param .Copy (),
442+ std::move (tensor_info_out.shape ));
443+ ORT_RETURN_IF_NOT (qnn_model_wrapper->AddTensorWrapper (std::move (tensor_wrapper_out)),
444+ " CreateReduceSumMulBroadcastX: failed to AddTensorWrapper" );
445+
446+ ORT_RETURN_IF_NOT (qnn_model_wrapper->CreateQnnNode (
447+ /* qnn_node_name=*/ out_name,
448+ /* package_name=*/ QNN_OP_PACKAGE_NAME_QTI_AISW,
449+ /* qnn_node_type=*/ QNN_OP_REDUCE_SUM,
450+ /* input_names=*/ {mul_out_name},
451+ /* output_names=*/ {out_name},
452+ /* param_tensor_names=*/ std::move (param_tensor_names),
453+ /* do_op_validation=*/ do_op_validation),
454+ " CreateReduceSumMulBroadcastX: failed to create ReduceSum node" );
455+
456+ return Status::OK ();
457+ }
458+
308459} // namespace
309460
310461namespace onnxruntime {
@@ -356,9 +507,20 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
356507 if (!IsEquationMatMul (parsed_equation.value ()) &&
357508 !IsEquationMatMulTransposeY (parsed_equation.value ()) &&
358509 !IsEquationMatMulBroadcastTransposeY (parsed_equation.value ()) &&
359- !IsEquationMatMulTransposeAll (parsed_equation.value ())) {
510+ !IsEquationMatMulTransposeAll (parsed_equation.value ()) &&
511+ !IsEquationReduceSumMulBroadcastX (parsed_equation.value ())) {
360512 return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, node_unit.OpType () + " unsupported equation: " + equation);
361513 }
514+ if (IsEquationReduceSumMulBroadcastX (parsed_equation.value ())) {
515+ if (IsGpuBackend (qnn_model_wrapper.GetQnnBackendType ())) {
516+ // QAIRT 3.36.1: Failed to validate on GPU.
517+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, node_unit.OpType () + " unsupported equation: " + equation + " on backend GPU" );
518+ }
519+ if (node_unit.Inputs ()[0 ].quant_param .has_value ()) {
520+ // QAIRT 3.36.1: Failed to finalize QNN graph 1002.
521+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, node_unit.OpType () + " unsupported equation: " + equation + " for quantized inputs" );
522+ }
523+ }
362524 return AddToModelBuilder (qnn_model_wrapper, node_unit, logger, true );
363525}
364526
@@ -408,6 +570,11 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
408570 /* node_unit=*/ node_unit,
409571 /* input_names=*/ std::move (input_names),
410572 /* do_op_validation=*/ do_op_validation));
573+ } else if (IsEquationReduceSumMulBroadcastX (parsed_equation.value ())) {
574+ ORT_RETURN_IF_ERROR (CreateReduceSumMulBroadcastX (/* qnn_model_wrapper=*/ &qnn_model_wrapper,
575+ /* node_unit=*/ node_unit,
576+ /* input_names=*/ std::move (input_names),
577+ /* do_op_validation=*/ do_op_validation));
411578 } else {
412579 return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, node_unit.OpType () + " unsupported equation: " + equation);
413580 }
0 commit comments