Commit edf6ff0

[webgpu] Move WebGPU DP4A shaders into templates (microsoft#25724)
### Description

Moves DP4A shaders into templates.

### Motivation and Context

Preparation for upcoming changes to add 2-bit quantization and MOE. Moving to templates will improve code readability.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent cb0c5e9 commit edf6ff0

File tree

7 files changed: +555 -521 lines changed


onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template

Lines changed: 276 additions & 0 deletions
Large diffs are not rendered by default.

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param n_bits
#param has_zero_points

#include "quantization/matmul_nbits_zero_pt.wgsl.template"

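// Expands two u32s holding 16 packed 4-bit weights (two per byte, low nibble first)
// into four u32s of packed signed 8-bit values, with the zero point subtracted.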
#if n_bits == 4
fn DequantizedFrom4BitsTo8Bits(in: vec2<u32>, zero: i32) -> vec4<u32>
{
  var out = vec4<u32>(0);
  var value_lower = vec4<i32>(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4<i32>(zero);
  var value_upper = vec4<i32>(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(zero);
  out[0] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
  out[1] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
  value_lower = vec4<i32>(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4<i32>(zero);
  value_upper = vec4<i32>(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(zero);
  out[2] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
  out[3] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
  return out;
}
#endif

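// Re-biases four u32s of packed unsigned 8-bit weights from [0, 255] to signed [-128, 127]
// so they can be fed to dot4I8Packed.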
#if n_bits == 8
fn AlignWithZeroPoint(in: vec4<u32>) -> vec4<u32>
{
  var out = vec4<u32>(0);
  out[0] = pack4xI8(vec4<i32>(unpack4xU8(in[0])) - vec4<i32>(128));
  out[1] = pack4xI8(vec4<i32>(unpack4xU8(in[1])) - vec4<i32>(128));
  out[2] = pack4xI8(vec4<i32>(unpack4xU8(in[2])) - vec4<i32>(128));
  out[3] = pack4xI8(vec4<i32>(unpack4xU8(in[3])) - vec4<i32>(128));
  return out;
}
#endif

// For 8 bits, the int32 result of dot4I8Packed could overflow when converted to f16, so we force the conversion to go through f32,
// then apply the scale, and finally convert to the output element type.
#if has_zero_points && n_bits == 8
// If has_zero_points is true, vec4<i32>(unpack4xU8(b_data)) - vec4<i32>(zero) may be out of the range [-128, 127], since zero can be any value in [0, 255].
// To avoid overflow when using pack4xI8, we still process the b data as pack4xI8(vec4<i32>(unpack4xU8(xxx)) - vec4<i32>(128)). In SDP8AI, we then
// subtract dot(vec4<i32>(unpack4xI8(a)), vec4<i32>(zero - 128)) from the dp4a result of a and b to get the correct result.
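// Per lane, a * (b - zero) = a * ((b - 128) - (zero - 128)); summing over all lanes gives
// dot(a, b - zero) = dot4I8Packed(a, pack4xI8(b - 128)) - sum(a) * (zero - 128),
// which is exactly local_sum - dot(dequantized_a_sum, vec4<i32>(bias_zero)) computed below.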
// Scaled dot product of 8 packed unsigned integers.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t, zero: i32) -> output_element_t
{
  let bias_zero = zero - 128;
  var local_sum = dot4I8Packed(a1[0], b1[0]);
  var dequantized_a_sum = vec4<i32>(unpack4xI8(a1[0]));
  local_sum += dot4I8Packed(a1[1], b1[1]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a1[1]));
  local_sum += dot4I8Packed(a1[2], b1[2]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a1[2]));
  local_sum += dot4I8Packed(a1[3], b1[3]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a1[3]));
  local_sum += dot4I8Packed(a2[0], b2[0]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a2[0]));
  local_sum += dot4I8Packed(a2[1], b2[1]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a2[1]));
  local_sum += dot4I8Packed(a2[2], b2[2]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a2[2]));
  local_sum += dot4I8Packed(a2[3], b2[3]);
  dequantized_a_sum += vec4<i32>(unpack4xI8(a2[3]));
  local_sum -= dot(dequantized_a_sum, vec4<i32>(bias_zero));
  return output_element_t(f32(local_sum) * f32(scale));
}
#else
// Scaled dot product of 8 packed unsigned integers.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
  var local_sum = dot4I8Packed(a1[0], b1[0]);
  local_sum += dot4I8Packed(a1[1], b1[1]);
  local_sum += dot4I8Packed(a1[2], b1[2]);
  local_sum += dot4I8Packed(a1[3], b1[3]);
  local_sum += dot4I8Packed(a2[0], b2[0]);
  local_sum += dot4I8Packed(a2[1], b2[1]);
  local_sum += dot4I8Packed(a2[2], b2[2]);
  local_sum += dot4I8Packed(a2[3], b2[3]);
  return output_element_t(f32(local_sum) * f32(scale));
}
#endif
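
The zero-point correction used by SDP8AI can be checked against plain scalar arithmetic. Below is a minimal Python sketch (illustrative only; the helper names are not from the commit) that verifies dot(a, b - zero) equals the biased dp4a-style sum over (b - 128) minus sum(a) * (zero - 128):

# Minimal sketch: checks that the SDP8AI zero-point trick,
# dot(a, b - 128) - sum(a) * (zero - 128), equals the plain dot(a, b - zero)
# for uint8 weights b, int8 activations a, and a uint8 zero point.
import random

def sdp8ai_style(a, b, zero):
    # b is stored biased by 128 so it fits in int8, as AlignWithZeroPoint does.
    b_biased = [x - 128 for x in b]               # values in [-128, 127]
    local_sum = sum(ai * bi for ai, bi in zip(a, b_biased))
    return local_sum - sum(a) * (zero - 128)      # subtract dot(a, zero - 128)

def reference(a, b, zero):
    return sum(ai * (bi - zero) for ai, bi in zip(a, b))

random.seed(0)
for _ in range(1000):
    a = [random.randint(-128, 127) for _ in range(32)]   # 8 packed dwords of int8
    b = [random.randint(0, 255) for _ in range(32)]      # uint8 weights
    zero = random.randint(0, 255)
    assert sdp8ai_style(a, b, zero) == reference(a, b, zero)
print("zero-point correction matches the reference")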

0 commit comments
