Skip to content

Commit a61fb39

Browse files
authored
[webgpu] Fix poor performance in flash attention for Qualcomm devices (microsoft#25730)
It seems that when multiple threads in one subgroup access the same shared memory location, the performance is poor on Qualcomm devices (bank conflicts?). If we limit the number of threads accessing the same memory location, the performance is greatly improved on Qualcomm devices. Phi4 becomes ~10s from ~13s on QC Adreno X1-85 (31.0.112.0).
1 parent 8e871c8 commit a61fb39

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,14 @@ $MAIN {
149149
var qk_4 : vec4<q_element_t>;
150150
if (sg_size > 8) {
151151
for (var i : u32 = 0u; i < head_size_vec; i++) {
152+
#if is_qualcomm
153+
var k_local = q_value_t(0);
154+
if (sg_id < max_k_step) {
155+
k_local = k_tile[sg_id][i];
156+
}
157+
#else
152158
var k_local = k_tile[capped_sg_id][i];
159+
#endif
153160
var q_own = q_tile[i];
154161
qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0));
155162
qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1));
@@ -243,7 +250,10 @@ $MAIN {
243250
#if is_qualcomm
244251
if (sg_size > 8) {
245252
for (var i : u32 = 0; i < half_head_size_vec; i++) {
246-
var val = v_tile[capped_sg_id][i];
253+
var val = q_value_t(0);
254+
if (sg_id < max_k_step) {
255+
val = v_tile[sg_id][i];
256+
}
247257
var sum = subgroupShuffle(val, 0) * qk_1[0];
248258
sum += subgroupShuffle(val, 1) * qk_1[1];
249259
sum += subgroupShuffle(val, 2) * qk_1[2];
@@ -262,7 +272,9 @@ $MAIN {
262272
sum += subgroupShuffle(val, 15) * qk_4[3];
263273
o_tile[i] = o_tile[i] * o_ratio + sum;
264274

265-
val = v_tile[capped_sg_id][half_head_size_vec + i];
275+
if (sg_id < max_k_step) {
276+
val = v_tile[sg_id][half_head_size_vec + i];
277+
}
266278
sum = subgroupShuffle(val, 0) * qk_1[0];
267279
sum += subgroupShuffle(val, 1) * qk_1[1];
268280
sum += subgroupShuffle(val, 2) * qk_1[2];
@@ -353,4 +365,4 @@ $MAIN {
353365
writeo(q_idx_global, head_idx);
354366
}
355367
#endif
356-
} // MAIN
368+
} // MAIN

0 commit comments

Comments
 (0)