Skip to content

Commit 1ad9f12

Browse files
authored
[webgpu] Use 64 as the workgroup size of DP4AMatMulQuantize (microsoft#24129)
Usually, workgroup size 1 is not a good option for a compute shader: it means that only one thread is active in each workgroup. This PR uses 64 as the workgroup size of DP4AMatMulQuantize. Results: on a Qualcomm Adreno X1-85 GPU, 721.13 ms -> 148.38 ms; on an NVIDIA RTX 2000 Ada, 87.66 ms -> 14.51 ms; on an Intel Xe GPU, 76.30 ms -> 42.96 ms.
1 parent 850be8e commit 1ad9f12

File tree

3 files changed

+91
-21
lines changed

3 files changed

+91
-21
lines changed

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,28 +73,30 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
7373

7474
constexpr uint32_t kBlockSizeA = 128;
7575
DP4AMatMulQuantizeProgram quantize_program;
76-
quantize_program.SetWorkgroupSize(1);
77-
quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
76+
quantize_program.SetWorkgroupSize(64);
77+
uint32_t tile_size = 64 * kVec4Components;
78+
quantize_program.SetDispatchGroupSize((M * K + tile_size - 1) / tile_size, 1, 1);
7879
TensorShape a_quant_shape{1, M, K / kU32Components};
7980
Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
8081
TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
8182
Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
8283
quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
8384
.AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1},
84-
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
85+
{&a_scale, ProgramTensorMetadataDependency::Rank, 1}})
86+
.AddUniformVariable({M * K / kU32Components});
8587
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
8688
const bool has_zero_points = zero_points != nullptr;
8789
if (M < min_M_for_tile_optimization) {
8890
uint32_t tile_size_k_vec = 16;
89-
uint32_t tile_size = 32;
91+
uint32_t tile_size_n = 32;
9092

9193
if (context.AdapterInfo().vendor == std::string_view{"intel"}) {
9294
tile_size_k_vec = 32;
93-
tile_size = 4;
95+
tile_size_n = 4;
9496
}
9597

96-
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size, nbits, has_zero_points};
97-
uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
98+
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points};
99+
uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
98100
mul_program.SetWorkgroupSize(128);
99101
mul_program.SetDispatchGroupSize(M * num_N_tile);
100102
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
@@ -103,7 +105,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
103105
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
104106
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col})
105107
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
106-
.CacheHint(nbits, tile_size_k_vec, tile_size, has_zero_points);
108+
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points);
107109
if (has_zero_points) {
108110
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
109111
}

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram
1616
public:
1717
DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
1818
Status GenerateShaderCode(ShaderHelper& sh) const override;
19+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
1920
};
2021

2122
class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,88 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
// Template for DP4A Matrix Multiply Quantization
5+
// Quantizes input matrix A for DP4A computation
6+
// This shader quantizes float values to 8-bit signed integers using pack4x8snorm
7+
8+
var<workgroup> a_values : array<array<input_a_value_t, 32>, 2>;
9+
var<workgroup> max_values : array<input_a_value_t, 4>;
10+
11+
fn readInput(offset: u32) -> input_a_value_t
12+
{
13+
if (offset >= uniforms.output_size) {
14+
return input_a_value_t(0);
15+
}
16+
return input_a[offset];
17+
}
18+
419
$MAIN {
5-
var local_a : array<vec4<input_a_element_t>, 32>;
6-
var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
7-
for (var idx:u32=0;idx<32;idx+=1)
20+
if (sg_size == 32) {
21+
let local_a = readInput(global_idx);
22+
let max_val = subgroupMax(abs(local_a));
23+
if (global_idx >= uniforms.output_size) {
24+
return;
25+
}
26+
let max_temp = max(max_val.xy, max_val.zw);
27+
let scale = max(max_temp[0], max_temp[1]);
28+
let norm_a = local_a/scale;
29+
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
30+
if (local_idx % 32 == 0)
31+
{
32+
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
33+
scales[workgroup_idx * 2 + local_idx / 32] = scale/127;
34+
}
35+
} else if (sg_size == 16) {
36+
let local_a = readInput(global_idx);
37+
let sub_max_value = subgroupMax(abs(local_a));
38+
if (local_idx % 16 == 0) {
39+
max_values[local_idx / 16] = sub_max_value;
40+
}
41+
workgroupBarrier();
42+
43+
if (global_idx >= uniforms.output_size) {
44+
return;
45+
}
46+
47+
var max_val = input_a_value_t(0);
48+
if (local_idx < 32) {
49+
max_val = max(max_values[0], max_values[1]);
50+
} else {
51+
max_val = max(max_values[2], max_values[3]);
52+
}
53+
let max_temp = max(max_val.xy, max_val.zw);
54+
let scale = max(max_temp[0], max_temp[1]);
55+
let norm_a = local_a/scale;
56+
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
57+
if (local_idx % 32 == 0)
58+
{
59+
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
60+
scales[workgroup_idx * 2 + local_idx / 32] = scale/127;
61+
}
62+
} else {
63+
let local_row = local_idx / 32u;
64+
let local_col = local_idx % 32u;
65+
a_values[local_row][local_col] = readInput(global_idx);
66+
workgroupBarrier();
67+
68+
if (global_idx >= uniforms.output_size) {
69+
return;
70+
}
71+
72+
var max_val = input_a_value_t(0);
73+
// TODO: Optimize this part so that all the threads are not computing the same value.
74+
for (var i = 0u; i < 32u; i++)
875
{
9-
local_a[idx] = input_a[workgroup_idx*32 + idx];
10-
max_value = max(max_value, abs(local_a[idx]));
76+
max_val = max(max_val, abs(a_values[local_row][i]));
1177
}
12-
var scale = max(max_value.x, max_value.y);
13-
scale = max(scale, max_value.z);
14-
scale = max(scale, max_value.w);
15-
for (var idx:u32=0;idx<32;idx+=1)
78+
let max_temp = max(max_val.xy, max_val.zw);
79+
let scale = max(max_temp[0], max_temp[1]);
80+
let norm_a = a_values[local_row][local_col]/scale;
81+
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
82+
if (local_col == 0u)
1683
{
17-
output[workgroup_idx*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
84+
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
85+
scales[workgroup_idx * 2 + local_row] = scale/127;
1886
}
19-
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
20-
scales[workgroup_idx] = scale/127;
21-
} // MAIN
87+
}
88+
}

0 commit comments

Comments
 (0)