
Commit c9c23b0

[QNN-EP] Enable einsum with QK equations for QNN. (microsoft#25861)
### Description
Enable the einsum op with QK equations for attention in the QNN EP.

### Motivation and Context
The current einsum op in QNN doesn't support equations containing uppercase letters. Loosen this constraint to allow more use cases.

Signed-off-by: Mu-Chein Hsu <quic_muchhsu@quicinc.com>
1 parent 16ae99e commit c9c23b0
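To make the motivation concrete: QK-style attention equations use uppercase subscripts for the query and key sequence axes. Below is a minimal, standalone C++ sketch (not QNN EP code; the helper name `EinsumBQdBKd` is made up for illustration) of what an equation such as `bQd,bKd->bQK` computes: a batched Q·Kᵀ, contracting over the head dimension `d`.

```cpp
#include <cstddef>
#include <vector>

// 3-D tensor as nested vectors, indexed [b][row][col]; illustration only.
using Tensor3D = std::vector<std::vector<std::vector<float>>>;

// Reference semantics of einsum "bQd,bKd->bQK":
//   out[b][q][k] = sum_d q_in[b][q][d] * k_in[b][k][d]
Tensor3D EinsumBQdBKd(const Tensor3D& q_in, const Tensor3D& k_in) {
  const std::size_t B = q_in.size();
  const std::size_t Q = q_in[0].size();
  const std::size_t D = q_in[0][0].size();
  const std::size_t K = k_in[0].size();
  Tensor3D out(B, std::vector<std::vector<float>>(Q, std::vector<float>(K, 0.0f)));
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t q = 0; q < Q; ++q)
      for (std::size_t k = 0; k < K; ++k)
        for (std::size_t d = 0; d < D; ++d)
          out[b][q][k] += q_in[b][q][d] * k_in[b][k][d];  // contract over d
  return out;
}
```

The new tests below exercise exactly these patterns ("bQd,bKd->bQK", "bnQd,bnKd->bnQK", "hQK,hKd->hQd") on the QNN CPU and HTP backends.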

File tree

2 files changed: +156 −3 lines changed


onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc

Lines changed: 3 additions & 3 deletions
@@ -48,13 +48,13 @@ std::optional<Equation> ParseEquation(std::string_view equation_string) {
   if (term_1.size() < 2 || term_2.size() < 2 || result.size() < 2) {
     return std::nullopt;
   }
-  if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) {
+  if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::isalpha(c); })) {
     return std::nullopt;
   }
-  if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::islower(c); })) {
+  if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::isalpha(c); })) {
     return std::nullopt;
   }
-  if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::islower(c); })) {
+  if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::isalpha(c); })) {
     return std::nullopt;
   }
   return std::make_tuple(term_1, term_2, result);
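As a quick illustration of the one-line predicate change above, here is a small standalone C++ sketch (the `AllLower`/`AllAlpha` helpers are hypothetical, written only to mirror the per-term checks in `ParseEquation`): the old `std::islower` check rejects a subscript term such as "bnQd", while `std::isalpha` accepts any mix of lower- and uppercase letters.

```cpp
#include <algorithm>
#include <cctype>
#include <iostream>
#include <string_view>

// Hypothetical helpers mirroring the per-term validation in ParseEquation().
bool AllLower(std::string_view term) {
  return std::all_of(term.begin(), term.end(),
                     [](unsigned char c) { return std::islower(c); });
}

bool AllAlpha(std::string_view term) {
  return std::all_of(term.begin(), term.end(),
                     [](unsigned char c) { return std::isalpha(c); });
}

int main() {
  constexpr std::string_view term = "bnQd";  // typical attention subscript term
  std::cout << "islower-only check: " << AllLower(term) << '\n';  // prints 0 (rejected)
  std::cout << "isalpha check:      " << AllAlpha(term) << '\n';  // prints 1 (accepted)
  return 0;
}
```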

onnxruntime/test/providers/qnn/einsum_op_test.cc

Lines changed: 153 additions & 0 deletions
@@ -150,6 +150,32 @@ TEST_F(QnnCPUBackendTests, EinsumRank2) {
       /*tolerance=*/1e-4f);
 }
 
+TEST_F(QnnCPUBackendTests, EinsumRank3MatMul) {
+  const std::vector<int64_t> shape0{4, 5, 6};
+  const std::vector<int64_t> shape1{4, 6, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeCpu,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"hij,hjk->hik",
+      /*tolerance=*/1e-4f);
+}
+
+TEST_F(QnnCPUBackendTests, EinsumRank3MatMul_QK) {
+  const std::vector<int64_t> shape0{4, 5, 6};
+  const std::vector<int64_t> shape1{4, 6, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeCpu,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"hQK,hKd->hQd",
+      /*tolerance=*/1e-4f);
+}
+
 TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) {
   const std::vector<int64_t> shape0{3, 4, 5, 6};
   const std::vector<int64_t> shape1{3, 4, 6, 5};
@@ -189,6 +215,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
       /*tolerance=*/1e-4f);
 }
 
+TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY_QK) {
+  const std::vector<int64_t> shape0{2, 3, 4, 6};
+  const std::vector<int64_t> shape1{2, 3, 5, 6};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeCpu,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bnQd,bnKd->bnQK",
+      /*tolerance=*/1e-4f);
+}
+
 TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
   const std::vector<int64_t> shape0{1, 7, 1, 7};
   const std::vector<int64_t> shape1{1, 9, 1, 7};
@@ -273,6 +312,60 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY) {
       /*tolerance=*/1e-2f);
 }
 
+TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY_QK) {
+  const std::vector<int64_t> shape0{2, 3, 4, 2};
+  const std::vector<int64_t> shape1{2, 3, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeHtp,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bnQd,bnKd->bnQK",
+      /*tolerance=*/1e-2f);
+}
+
+TEST_F(QnnHTPBackendTests, EinsumRank3MatMulTransposeY) {
+  const std::vector<int64_t> shape0{2, 4, 2};
+  const std::vector<int64_t> shape1{2, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeHtp,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bid,bjd->bij",
+      /*tolerance=*/1e-2f);
+}
+
+TEST_F(QnnHTPBackendTests, EinsumRank3MatMulTransposeY_QK) {
+  const std::vector<int64_t> shape0{2, 4, 2};
+  const std::vector<int64_t> shape1{2, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeHtp,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bQd,bKd->bQK",
+      /*tolerance=*/1e-2f);
+}
+
+// The value pair (65.1049271, 65.0625076) at index #51 don't match, which is -0.0424194 from 65.1049
+// Disable this Rank3 test on HTP since it has accuracy issue.
+TEST_F(QnnHTPBackendTests, DISABLED_EinsumRank3MatMul_QK) {
+  const std::vector<int64_t> shape0{4, 5, 6};
+  const std::vector<int64_t> shape1{4, 6, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum<float>(
+      /*backend=*/kQnnBackendTypeHtp,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"hQK,hKd->hQd",
+      /*tolerance=*/1e-2f);
+}
+
 TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll1) {
   const std::vector<int64_t> shape0{1, 3, 1, 7};
   const std::vector<int64_t> shape1{1, 7, 1, 3};
@@ -365,6 +458,66 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY) {
       /*tolerance=*/QDQTolerance());
 }
 
+TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY_QK) {
+  const std::vector<int64_t> shape0{2, 3, 4, 2};
+  const std::vector<int64_t> shape1{2, 3, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bnQd,bnKd->bnQK",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY) {
+  const std::vector<int64_t> shape0{2, 4, 2};
+  const std::vector<int64_t> shape1{2, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bid,bjd->bij",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY_QK) {
+  const std::vector<int64_t> shape0{2, 4, 2};
+  const std::vector<int64_t> shape1{2, 5, 2};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bQd,bKd->bQK",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMul) {
+  const std::vector<int64_t> shape0{4, 5, 6};
+  const std::vector<int64_t> shape1{4, 6, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"hij,hjk->hik",
+      /*tolerance=*/QDQTolerance());
+}
+
+TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMul_QK) {
+  const std::vector<int64_t> shape0{4, 5, 6};
+  const std::vector<int64_t> shape1{4, 6, 5};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"hQK,hKd->hQd",
+      /*tolerance=*/QDQTolerance());
+}
+
 TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll1) {
   const std::vector<int64_t> shape0{1, 3, 1, 7};
   const std::vector<int64_t> shape1{1, 7, 1, 3};
