@@ -150,6 +150,32 @@ TEST_F(QnnCPUBackendTests, EinsumRank2) {
150150 /* tolerance=*/ 1e-4f );
151151}
152152
153+ TEST_F (QnnCPUBackendTests, EinsumRank3MatMul) {
154+ const std::vector<int64_t > shape0{4 , 5 , 6 };
155+ const std::vector<int64_t > shape1{4 , 6 , 5 };
156+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
157+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
158+ RunQnnEinsum<float >(
159+ /* backend=*/ kQnnBackendTypeCpu ,
160+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
161+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
162+ /* equation=*/ " hij,hjk->hik" ,
163+ /* tolerance=*/ 1e-4f );
164+ }
165+
166+ TEST_F (QnnCPUBackendTests, EinsumRank3MatMul_QK) {
167+ const std::vector<int64_t > shape0{4 , 5 , 6 };
168+ const std::vector<int64_t > shape1{4 , 6 , 5 };
169+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
170+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
171+ RunQnnEinsum<float >(
172+ /* backend=*/ kQnnBackendTypeCpu ,
173+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
174+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
175+ /* equation=*/ " hQK,hKd->hQd" ,
176+ /* tolerance=*/ 1e-4f );
177+ }
178+
153179TEST_F (QnnCPUBackendTests, EinsumRank4MatMul) {
154180 const std::vector<int64_t > shape0{3 , 4 , 5 , 6 };
155181 const std::vector<int64_t > shape1{3 , 4 , 6 , 5 };
@@ -189,6 +215,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
189215 /* tolerance=*/ 1e-4f );
190216}
191217
218+ TEST_F (QnnCPUBackendTests, EinsumRank4MatMulTransposeY_QK) {
219+ const std::vector<int64_t > shape0{2 , 3 , 4 , 6 };
220+ const std::vector<int64_t > shape1{2 , 3 , 5 , 6 };
221+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
222+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
223+ RunQnnEinsum<float >(
224+ /* backend=*/ kQnnBackendTypeCpu ,
225+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
226+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
227+ /* equation=*/ " bnQd,bnKd->bnQK" ,
228+ /* tolerance=*/ 1e-4f );
229+ }
230+
192231TEST_F (QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
193232 const std::vector<int64_t > shape0{1 , 7 , 1 , 7 };
194233 const std::vector<int64_t > shape1{1 , 9 , 1 , 7 };
@@ -273,6 +312,60 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY) {
273312 /* tolerance=*/ 1e-2f );
274313}
275314
315+ TEST_F (QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY_QK) {
316+ const std::vector<int64_t > shape0{2 , 3 , 4 , 2 };
317+ const std::vector<int64_t > shape1{2 , 3 , 5 , 2 };
318+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
319+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
320+ RunQnnEinsum<float >(
321+ /* backend=*/ kQnnBackendTypeHtp ,
322+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
323+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
324+ /* equation=*/ " bnQd,bnKd->bnQK" ,
325+ /* tolerance=*/ 1e-2f );
326+ }
327+
328+ TEST_F (QnnHTPBackendTests, EinsumRank3MatMulTransposeY) {
329+ const std::vector<int64_t > shape0{2 , 4 , 2 };
330+ const std::vector<int64_t > shape1{2 , 5 , 2 };
331+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
332+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
333+ RunQnnEinsum<float >(
334+ /* backend=*/ kQnnBackendTypeHtp ,
335+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
336+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
337+ /* equation=*/ " bid,bjd->bij" ,
338+ /* tolerance=*/ 1e-2f );
339+ }
340+
341+ TEST_F (QnnHTPBackendTests, EinsumRank3MatMulTransposeY_QK) {
342+ const std::vector<int64_t > shape0{2 , 4 , 2 };
343+ const std::vector<int64_t > shape1{2 , 5 , 2 };
344+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
345+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
346+ RunQnnEinsum<float >(
347+ /* backend=*/ kQnnBackendTypeHtp ,
348+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
349+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
350+ /* equation=*/ " bQd,bKd->bQK" ,
351+ /* tolerance=*/ 1e-2f );
352+ }
353+
354+ // The value pair (65.1049271, 65.0625076) at index #51 don't match, which is -0.0424194 from 65.1049
355+ // Disable this Rank3 test on HTP since it has accuracy issue.
356+ TEST_F (QnnHTPBackendTests, DISABLED_EinsumRank3MatMul_QK) {
357+ const std::vector<int64_t > shape0{4 , 5 , 6 };
358+ const std::vector<int64_t > shape1{4 , 6 , 5 };
359+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
360+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
361+ RunQnnEinsum<float >(
362+ /* backend=*/ kQnnBackendTypeHtp ,
363+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
364+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
365+ /* equation=*/ " hQK,hKd->hQd" ,
366+ /* tolerance=*/ 1e-2f );
367+ }
368+
276369TEST_F (QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll1) {
277370 const std::vector<int64_t > shape0{1 , 3 , 1 , 7 };
278371 const std::vector<int64_t > shape1{1 , 7 , 1 , 3 };
@@ -365,6 +458,66 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY) {
365458 /* tolerance=*/ QDQTolerance ());
366459}
367460
461+ TEST_F (QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY_QK) {
462+ const std::vector<int64_t > shape0{2 , 3 , 4 , 2 };
463+ const std::vector<int64_t > shape1{2 , 3 , 5 , 2 };
464+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
465+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
466+ RunQnnHtpQdqEinsum<uint8_t , uint8_t >(
467+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
468+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
469+ /* equation=*/ " bnQd,bnKd->bnQK" ,
470+ /* tolerance=*/ QDQTolerance ());
471+ }
472+
473+ TEST_F (QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY) {
474+ const std::vector<int64_t > shape0{2 , 4 , 2 };
475+ const std::vector<int64_t > shape1{2 , 5 , 2 };
476+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
477+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
478+ RunQnnHtpQdqEinsum<uint8_t , uint8_t >(
479+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
480+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
481+ /* equation=*/ " bid,bjd->bij" ,
482+ /* tolerance=*/ QDQTolerance ());
483+ }
484+
485+ TEST_F (QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY_QK) {
486+ const std::vector<int64_t > shape0{2 , 4 , 2 };
487+ const std::vector<int64_t > shape1{2 , 5 , 2 };
488+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
489+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
490+ RunQnnHtpQdqEinsum<uint8_t , uint8_t >(
491+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
492+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
493+ /* equation=*/ " bQd,bKd->bQK" ,
494+ /* tolerance=*/ QDQTolerance ());
495+ }
496+
497+ TEST_F (QnnHTPBackendTests, EinsumQdqRank3MatMul) {
498+ const std::vector<int64_t > shape0{4 , 5 , 6 };
499+ const std::vector<int64_t > shape1{4 , 6 , 5 };
500+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
501+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
502+ RunQnnHtpQdqEinsum<uint8_t , uint8_t >(
503+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
504+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
505+ /* equation=*/ " hij,hjk->hik" ,
506+ /* tolerance=*/ QDQTolerance ());
507+ }
508+
509+ TEST_F (QnnHTPBackendTests, EinsumQdqRank3MatMul_QK) {
510+ const std::vector<int64_t > shape0{4 , 5 , 6 };
511+ const std::vector<int64_t > shape1{4 , 6 , 5 };
512+ const std::vector<float > data0 = GetSequentialFloatData (shape0, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
513+ const std::vector<float > data1 = GetSequentialFloatData (shape1, /* start=*/ -0 .1f , /* step=*/ 0 .05f );
514+ RunQnnHtpQdqEinsum<uint8_t , uint8_t >(
515+ /* in0=*/ TestInputDef<float >(shape0, /* is_initializer=*/ false , std::move (data0)),
516+ /* in1=*/ TestInputDef<float >(shape1, /* is_initializer=*/ false , std::move (data1)),
517+ /* equation=*/ " hQK,hKd->hQd" ,
518+ /* tolerance=*/ QDQTolerance ());
519+ }
520+
368521TEST_F (QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll1) {
369522 const std::vector<int64_t > shape0{1 , 3 , 1 , 7 };
370523 const std::vector<int64_t > shape1{1 , 7 , 1 , 3 };
0 commit comments