Commit 7e3174b
[webgpu] Optimize dp4 prefill shader for Qualcomm (microsoft#25578)
This change uses subgroupShuffle when sg_size == 64 to perform the matmul. It also uses a loop instead of loop unrolling to reduce register pressure. Phi4 prefill for 1K tokens improves from 11.32s to 8.8s on the Qualcomm Adreno X1-85 GPU.
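To illustrate the core trick, here is a minimal standalone WGSL sketch of the broadcast pattern (the buffer names, f32 element type, and workgroup size are illustrative assumptions, not code from the template): the first 16 lanes of a 64-wide subgroup hold values in private registers, and every lane then reads all 16 of them with subgroupShuffle in a loop, instead of each lane re-reading workgroup memory through 16 unrolled loads.

// Minimal sketch of the subgroupShuffle broadcast; requires the WebGPU
// "subgroups" feature. Names and types are assumptions, not from the template.
enable subgroups;

@group(0) @binding(0) var<storage, read> src : array<f32>;
@group(0) @binding(1) var<storage, read_write> dst : array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(local_invocation_index) local_idx : u32,
        @builtin(subgroup_invocation_id) sg_id : u32,
        @builtin(subgroup_size) sg_size : u32) {
  // Step 1: only lanes 0..15 load a value into a private register.
  var own_b : f32 = 0.0;
  if (sg_id < 16u) {
    own_b = src[sg_id];
  }
  var acc : f32 = 0.0;
  if (sg_size == 64u) {
    // Step 2: every lane reads lanes 0..15's registers via subgroupShuffle;
    // i is uniform across the subgroup, so each iteration is one broadcast.
    for (var i = 0u; i < 16u; i++) {
      acc += subgroupShuffle(own_b, i);
    }
  }
  dst[local_idx] = acc; // assumes a single-workgroup dispatch
}

In the real shader the shuffled registers hold tile_B data, scales, and zero points, and the loop body feeds them into SDP8AI; keeping the body a loop rather than 16 unrolled statements is what relieves register pressure on Adreno.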
Parent: f58f7eb

3 files changed: +97 −12 lines
onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template

Lines changed: 87 additions & 6 deletions
@@ -4,6 +4,7 @@
 #param block_size
 #param n_bits
 #param has_zero_points
+#param is_qualcomm

 #include "quantization/dp4a_matmul_common.wgsl.template"

@@ -138,18 +139,35 @@ $MAIN {

   // During the compute phase, we have the 64x64 tile split into
   // subtiles of 16x16. We have a grid of 4x4 subtiles.
-  let subtile_id = u32(local_idx / subtile_size);
-  let subtile_idx = u32(subtile_id / 4);
-  let subtile_idy = u32(subtile_id % 4);
-  let base_A = subtile_idx * 16;
-  let base_B = subtile_idy * 16;
+  var subtile_id = u32(local_idx / subtile_size);
+  var subtile_idx = u32(subtile_id / 4);
+  var subtile_idy = u32(subtile_id % 4);
+  var base_A = subtile_idx * 16;
+  var base_B = subtile_idy * 16;
   // For each subtile we have 16 threads assigned.
-  let a_idx = u32(local_idx % subtile_size);
+  var a_idx = u32(local_idx % subtile_size);

+#if is_qualcomm
+  // subtile_idx is always 0.
+  // subtile_idy is one of {0,1,2,3}.
+  // The subtile is rectangular (64x16) in the Qualcomm case, and we have 4 subtiles. This way we
+  // don't need to increase the number of lane outputs each thread tracks: a 64x64 subtile would
+  // require `var lane_outputs: array<output_element_t, 64>;`.
+  if (sg_size == 64) {
+    subtile_id = u32(local_idx / sg_size);
+    subtile_idx = u32(subtile_id / 4);
+    subtile_idy = u32(subtile_id % 4);
+    base_A = subtile_idx * sg_size;
+    base_B = subtile_idy * 16;
+    a_idx = sg_id;
+  }
+  var lane_outputs: array<output_element_t, 16>;
+#else
   var lane_output1: vec4<output_element_t>;
   var lane_output2: vec4<output_element_t>;
   var lane_output3: vec4<output_element_t>;
   var lane_output4: vec4<output_element_t>;
+#endif
   // K's vectorization is 16 items per index. See input_a/input_b.
   // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is
   // k tile size is 32. In vectorized space that is 32/16 = 2.
@@ -173,6 +191,34 @@ $MAIN {
   var own_scale_a: output_element_t = scale_A[base_A + a_idx];

 #if has_zero_points && n_bits == 8
+#if is_qualcomm
+  if (sg_size == 64)
+  {
+    var own_b0: vec4<u32>;
+    var own_b1: vec4<u32>;
+    var own_scale_b: output_element_t;
+    var zero: i32;
+    if (sg_id < 16)
+    {
+      own_b0 = tile_B[0][base_B + sg_id];
+      own_b1 = tile_B[1][base_B + sg_id];
+      own_scale_b = scale_B[base_B + sg_id];
+      zero = zeroes[base_B + sg_id];
+    }
+    // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
+    for (var i = 0u; i < 16u; i++)
+    {
+      lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a, subgroupShuffle(zero, i));
+    }
+  }
+  else
+  {
+    for (var i = 0u; i < 16u; i++)
+    {
+      lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i], zeroes[base_B + i]);
+    }
+  }
+#else
   if (sg_size == 16)
   {
     var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
@@ -225,7 +271,34 @@ $MAIN {
     lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]);
     lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]);
   }
+#endif
 #else
+#if is_qualcomm
+  if (sg_size == 64)
+  {
+    var own_b0: vec4<u32>;
+    var own_b1: vec4<u32>;
+    var own_scale_b: output_element_t;
+    if (sg_id < 16)
+    {
+      own_b0 = tile_B[0][base_B + sg_id];
+      own_b1 = tile_B[1][base_B + sg_id];
+      own_scale_b = scale_B[base_B + sg_id];
+    }
+    // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
+    for (var i = 0u; i < 16u; i++)
+    {
+      lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a);
+    }
+  }
+  else
+  {
+    for (var i = 0u; i < 16u; i++)
+    {
+      lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i]);
+    }
+  }
+#else
   if (sg_size == 16)
   {
     var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
@@ -277,6 +350,7 @@ $MAIN {
     lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
     lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
   }
+#endif
 #endif
     workgroupBarrier();
   }
@@ -287,9 +361,16 @@ $MAIN {
   // This creates a shader requirement that uniforms.N % 16 == 0
   if (a_global < uniforms.M && b_global < uniforms.N)
   {
+#if is_qualcomm
+    output[output_idx] = vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]);
+    output[output_idx+1] = vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]);
+    output[output_idx+2] = vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]);
+    output[output_idx+3] = vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]);
+#else
     output[output_idx] = lane_output1;
     output[output_idx+1] = lane_output2;
     output[output_idx+2] = lane_output3;
     output[output_idx+3] = lane_output4;
+#endif
   }
 } // MAIN
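To make the layout change concrete: the workgroup has 256 threads, so with sg_size == 64 it contains exactly four subgroups. The default path splits the 64x64 output tile into a 4x4 grid of 16x16 subtiles with 16 threads each; the Qualcomm path splits it into four 64x16 subtiles, one per subgroup, so each thread still accumulates only 16 outputs (lane_outputs). The helper below is a hypothetical distillation of that index math for comparison; it is not a function in the template.

// Hypothetical WGSL helper contrasting the two subtile layouts.
// Returns (base_A, base_B, a_idx) for a thread in a 256-thread workgroup.
fn subtile_bases(local_idx : u32, sg_id : u32, sg_size : u32) -> vec3<u32> {
  let subtile_size = 16u;
  // Default: 4x4 grid of 16x16 subtiles, 16 threads per subtile.
  var subtile_id = local_idx / subtile_size;   // 0..15
  var base_A = (subtile_id / 4u) * 16u;        // row base: 0, 16, 32, 48
  var base_B = (subtile_id % 4u) * 16u;        // col base: 0, 16, 32, 48
  var a_idx = local_idx % subtile_size;        // 0..15
  if (sg_size == 64u) {
    // Qualcomm: 4x1 grid of 64x16 subtiles, one 64-lane subgroup each.
    subtile_id = local_idx / sg_size;          // 0..3
    base_A = (subtile_id / 4u) * sg_size;      // always 0 (subtile_idx == 0)
    base_B = (subtile_id % 4u) * 16u;          // 0, 16, 32, 48
    a_idx = sg_id;                             // 0..63: full tile height
  }
  return vec3<u32>(base_A, base_B, a_idx);
}

Because each subgroup's subtile spans the full 64 rows while still covering only 16 columns, the four subgroups tile the same 64x64 output without any thread tracking more than 16 partial sums.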

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 4 additions & 2 deletions
@@ -28,6 +28,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template",
                              WGSL_TEMPLATE_PARAMETER(block_size, block_size_),
                              WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
+                             WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                              WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
                              WGSL_TEMPLATE_PARAMETER(output_type_i32, true));
 }
@@ -118,7 +119,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
   TensorShape reshaped_y_shape{1, M, N / kVec4Components};
   uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
   uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
-  DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points};
+  bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
+  DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm};
   mul_program.SetWorkgroupSize(256);
   mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile);
   mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
@@ -133,7 +135,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
                          {num_N_tile},
                          {zero_blocks_per_col}})
       .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)})
-      .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points);
+      .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm);
   if (has_zero_points) {
     mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
   }

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h

Lines changed: 6 additions & 4 deletions
@@ -21,10 +21,11 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram

 class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
  public:
-  DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBits"},
-                                                                                      block_size_(block_size),
-                                                                                      nbits_(nbits),
-                                                                                      has_zero_points_(has_zero_points) {}
+  DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points, bool is_qualcomm) : Program{"DP4AMatMulNBits"},
+                                                                                                        block_size_(block_size),
+                                                                                                        nbits_(nbits),
+                                                                                                        has_zero_points_(has_zero_points),
+                                                                                                        is_qualcomm_(is_qualcomm) {}
   Status GenerateShaderCode(ShaderHelper& sh) const override;
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
       {"M", ProgramUniformVariableDataType::Uint32},
@@ -39,6 +40,7 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
   uint32_t block_size_;
   uint32_t nbits_;
   bool has_zero_points_;
+  bool is_qualcomm_;
 };

 class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
