Commit 850be8e
[webgpu] Move MatMulNBits to templates (microsoft#25783)
### Description
This change moves the regular MatMulNBits shader to templates: the WGSL source now lives in a .wgsl.template file whose #param values are bound from C++ via WGSL_TEMPLATE_PARAMETER (see the diff below).
1 parent 0d04ad3 commit 850be8e
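For orientation: the WebGPU EP's shader templates replace stream-assembled WGSL with standalone .wgsl.template files, where #param declares a specialization constant, #if / #elif / #endif select code at specialization time, and $MAIN marks the shader entry body, while the C++ side binds values through WGSL_TEMPLATE_PARAMETER and WGSL_TEMPLATE_APPLY (all visible in the diff below). The following is only a toy sketch of the substitution idea; ApplyParams is hypothetical and says nothing about how onnxruntime's real template pipeline is implemented.

#include <cstdint>
#include <cstdio>
#include <map>
#include <regex>
#include <string>

// Hypothetical helper, NOT onnxruntime's implementation: replaces each
// "#param name" declaration with a WGSL constant so the rest of the template
// can use `name` as an ordinary identifier.
static std::string ApplyParams(std::string text,
                               const std::map<std::string, uint32_t>& params) {
  for (const auto& [name, value] : params) {
    text = std::regex_replace(text, std::regex("#param " + name),
                              "const " + name + " = " + std::to_string(value) + "u;");
  }
  return text;
}

int main() {
  const std::string tmpl =
      "#param tile_size\n"
      "var<workgroup> tile : array<f32, tile_size>;\n";
  // Prints:
  //   const tile_size = 32u;
  //   var<workgroup> tile : array<f32, tile_size>;
  std::printf("%s", ApplyParams(tmpl, {{"tile_size", 32}}).c_str());
  return 0;
}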

File tree

2 files changed: +135 -135 lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 18 additions & 135 deletions
@@ -71,144 +71,27 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
     shader.AddInput("zero_points", ShaderUsage::UseUniform);
   }
   shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);
+
   const uint32_t components_a = a.NumComponents();
   const uint32_t components_b = b.NumComponents() / 4;  // b is stored as uint32, which includes 4 uint8.
   constexpr uint32_t tile_size_k_vec = 16;
-  uint32_t elements_in_value_b = components_b * (32 / nbits_);
-  uint32_t tile_k_size = tile_size_k_vec * elements_in_value_b;
-  const uint32_t a_length_per_tile = tile_k_size / components_a;
-
-  shader.AdditionalImplementation() << "const a_length_per_tile = " << a_length_per_tile << "u;\n"
-                                    << "const tile_size_k_vec = " << tile_size_k_vec << ";\n"
-                                    << "const tile_size_k = " << tile_k_size << "u;\n"
-                                    << "const tile_size = " << tile_size_ << "u;\n"
-                                    << "const elements_in_value_b = " << elements_in_value_b << "u;\n"
-                                    << "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"
-                                    << "const component_a = " << components_a << "u;\n"
-                                    << "const component_b = " << components_b << "u;\n";
-  shader.AdditionalImplementation() << R"ADDNL_FN(
-  // Shared memory
-  var<workgroup> tile_A : array<input_a_value_t, a_length_per_tile>;
-  var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
-  fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32)
-  {
-    let k_offset = kidx / component_a + col;
-    if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) {
-      tile_A[col] = input_a[batch * uniforms.M * uniforms.K_of_a + a_global * uniforms.K_of_a + k_offset];
-    } else {
-      tile_A[col] = input_a_value_t(0);
-    }
-  }
-)ADDNL_FN"
-                                    << GenerateZeroPointReadingCode(nbits_, has_zero_points_);
-
-  shader.MainFunctionBody() << R"MAIN_FN(
-  let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile);
-  let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M;
-  let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
-
-  let idx = local_idx % tile_size_k_vec;
-  let idy = local_idx / tile_size_k_vec;
-
-  for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k)
-  {
-    for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x)
-    {
-      loadSHMA(batch, a_global, kidx, id);
-    }
-    workgroupBarrier();
-
-    for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count)
-    {
-      var b_global = b_global_base + local_row_offset + idy;
-      var k_offset = kidx / elements_in_value_b + idx;
-      if (b_global < uniforms.N && k_offset < uniforms.K_of_b)
-      {
-        let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size;
-        let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx];
-        let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
-        var b_value = input_b[b_global * uniforms.K_of_b + k_offset];
-)MAIN_FN";
-
-  if (nbits_ == 4) {
-    shader.MainFunctionBody() << R"MAIN_FN(
-        var sum = output_element_t(0);
-        var a_offset = idx * (8 / component_a) * component_b;
-        for (var i = 0u; i < component_b; i++) {
-          let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
-          let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
-          let b0 = vec4<output_element_t>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b;
-          let b1 = vec4<output_element_t>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b;
-)MAIN_FN";
-    switch (components_a) {
-      case 1:
-        shader.MainFunctionBody() << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +"
-                                     " dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);\n"
-                                     " a_offset += 8;\n";
-        break;
-      case 2:
-        shader.MainFunctionBody() << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +"
-                                     "dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);\n"
-                                     " a_offset += 4;\n";
-        break;
-      case 4:
-        shader.MainFunctionBody() << " sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);\n"
-                                     " a_offset += 2;\n";
-        break;
-      default:
-        break;
-    }
-    shader.MainFunctionBody() << " }\n";
-  } else {
-    shader.MainFunctionBody() << R"MAIN_FN(
-        var sum = output_element_t(0);
-        var a_offset = idx * (4 / component_a) * component_b;
-        for (var i = 0u; i < component_b; i++) {
-          let b_value = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
-)MAIN_FN";
-    switch (components_a) {
-      case 1:
-        shader.MainFunctionBody() << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);\n"
-                                     " a_offset += 4;\n";
-        break;
-      case 2:
-        shader.MainFunctionBody() << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value);\n"
-                                     " a_offset += 2;\n";
-        break;
-      case 4:
-        shader.MainFunctionBody() << " sum += dot(tile_A[a_offset], b_value);\n"
-                                     " a_offset += 1;\n";
-        break;
-      default:
-        break;
-    }
-    shader.MainFunctionBody() << " }\n";
-  }
-
-  shader.MainFunctionBody() << R"MAIN_FN(
-        inter_results[local_row_offset + idy][idx] += sum;
-      }
-    }
-    workgroupBarrier();
-  }
-
-  if (batch >= uniforms.batch_count) {
-    return;
-  }
-
-  if (local_idx < tile_size) {
-    var output_value = output_element_t(0);
-    for (var b = 0u; b < tile_size_k_vec; b++) {
-      output_value += inter_results[local_idx][b];
-    }
-    let b_global = b_global_base + local_idx;
-    let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global;
-    if (b_global < uniforms.N) {
-      output[output_idx] = output_value;
-    }
-  }
-)MAIN_FN";
-
+  const uint32_t elements_in_value_b = components_b * (32 / nbits_);
+  const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
+  const uint32_t a_length_per_tile = tile_size_k / components_a;
+  uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
+
+  return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile),
+                             WGSL_TEMPLATE_PARAMETER(component_a, components_a),
+                             WGSL_TEMPLATE_PARAMETER(component_b, components_b),
+                             WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b),
+                             WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
+                             WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
+                             WGSL_TEMPLATE_PARAMETER(output_type_i32, false),
+                             WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
+                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k),
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
   return Status::OK();
 }
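To make the newly bound parameters concrete, here is a small standalone sketch (not part of the commit) that evaluates the derived constants for one assumed configuration; the specific numbers (components_a = 4, components_b = 4, 128 threads per workgroup) are illustrative assumptions, not values fixed by the diff.

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed configuration, for illustration only: 4-bit weights, B loaded as
  // vec4<u32> (components_b == 4), A loaded as vec4 (components_a == 4), and
  // a workgroup size of 128 threads in x.
  const uint32_t nbits = 4;
  const uint32_t components_a = 4;
  const uint32_t components_b = 4;
  constexpr uint32_t tile_size_k_vec = 16;
  const uint32_t workgroup_size_x = 128;

  // One u32 of B packs 32 / nbits quantized weights, so one vec4<u32> load
  // covers components_b * (32 / nbits) positions along K.
  const uint32_t elements_in_value_b = components_b * (32 / nbits);  // 4 * 8 = 32
  // A K-tile spans tile_size_k_vec such packed loads.
  const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;  // 16 * 32 = 512
  // tile_A stores that K-tile as vectors of components_a elements of A.
  const uint32_t a_length_per_tile = tile_size_k / components_a;  // 512 / 4 = 128
  // Threads form rows of tile_size_k_vec; each pass covers sub_tile_count
  // rows of the N-tile.
  const uint32_t sub_tile_count = workgroup_size_x / tile_size_k_vec;  // 128 / 16 = 8
  std::printf("elements_in_value_b=%u tile_size_k=%u a_length_per_tile=%u sub_tile_count=%u\n",
              elements_in_value_b, tile_size_k, a_length_per_tile, sub_tile_count);
  return 0;
}

With these assumed numbers, a 128-thread workgroup stages a 512-element K-slice of A in tile_A and sweeps its N-tile eight rows at a time.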

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#param a_length_per_tile
+#param component_a
+#param component_b
+#param elements_in_value_b
+#param n_bits
+#param sub_tile_count
+#param tile_size_k_vec
+#param tile_size_k
+#param tile_size
+
+#include "quantization/matmul_nbits_zero_pt.wgsl.template"
+
+// Shared memory
+var<workgroup> tile_A : array<input_a_value_t, a_length_per_tile>;
+var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
+
+fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32)
+{
+  let k_offset = kidx / component_a + col;
+  if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) {
+    tile_A[col] = input_a[batch * uniforms.M * uniforms.K_of_a + a_global * uniforms.K_of_a + k_offset];
+  } else {
+    tile_A[col] = input_a_value_t(0);
+  }
+}
+
+$MAIN {
+  let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile);
+  let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M;
+  let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
+
+  let idx = local_idx % tile_size_k_vec;
+  let idy = local_idx / tile_size_k_vec;
+
+  for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k)
+  {
+    for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x)
+    {
+      loadSHMA(batch, a_global, kidx, id);
+    }
+    workgroupBarrier();
+
+    for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count)
+    {
+      var b_global = b_global_base + local_row_offset + idy;
+      var k_offset = kidx / elements_in_value_b + idx;
+      if (b_global < uniforms.N && k_offset < uniforms.K_of_b)
+      {
+        let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size;
+        let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx];
+        let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
+        var b_value = input_b[b_global * uniforms.K_of_b + k_offset];
+
+#if n_bits == 4
+        var sum = output_element_t(0);
+        var a_offset = idx * (8 / component_a) * component_b;
+        for (var i = 0u; i < component_b; i++) {
+          let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
+          let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
+          let b0 = vec4<output_element_t>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b;
+          let b1 = vec4<output_element_t>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b;
+#if component_a == 1
+          sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
+                 dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);
+          a_offset += 8;
+#elif component_a == 2
+          sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
+                 dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);
+          a_offset += 4;
+#elif component_a == 4
+          sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);
+          a_offset += 2;
+#endif
+        }
+#elif n_bits == 8
+        var sum = output_element_t(0);
+        var a_offset = idx * (4 / component_a) * component_b;
+        for (var i = 0u; i < component_b; i++) {
+          let b_value = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
+#if component_a == 1
+          sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);
+          a_offset += 4;
+#elif component_a == 2
+          sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value);
+          a_offset += 2;
+#elif component_a == 4
+          sum += dot(tile_A[a_offset], b_value);
+          a_offset += 1;
+#endif
+        }
+#endif
+
+        inter_results[local_row_offset + idy][idx] += sum;
+      }
+    }
+    workgroupBarrier();
+  }
+
+  if (batch >= uniforms.batch_count) {
+    return;
+  }
+
+  if (local_idx < tile_size) {
+    var output_value = output_element_t(0);
+    for (var b = 0u; b < tile_size_k_vec; b++) {
+      output_value += inter_results[local_idx][b];
+    }
+    let b_global = b_global_base + local_idx;
+    let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global;
+    if (b_global < uniforms.N) {
+      output[output_idx] = output_value;
+    }
+  }
+} // MAIN
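The n_bits == 4 branch above relies on unpack4xU8 plus the 0x0F0F0F0F masks to split each packed u32 of B into eight 4-bit weights. Below is a CPU-side sketch of that unpacking (illustrative only, omitting the zero-point subtraction and scaling; it assumes MatMulNBits packs the lower nibble at the earlier K index, which is what the b0/b1 interleaving implies).

#include <array>
#include <cstdint>
#include <cstdio>

// CPU equivalent of WGSL's unpack4xU8: split a u32 into its 4 bytes,
// least-significant byte first.
static std::array<uint32_t, 4> unpack4xU8(uint32_t v) {
  return {v & 0xFF, (v >> 8) & 0xFF, (v >> 16) & 0xFF, (v >> 24) & 0xFF};
}

int main() {
  // One packed u32 holds 8 consecutive 4-bit weights along K.
  const uint32_t b = 0x87654321u;
  // Masking first keeps the low nibble of each byte; shifting right by 4
  // first keeps the high nibble.
  const auto lower = unpack4xU8(b & 0x0F0F0F0Fu);         // weights 0, 2, 4, 6
  const auto upper = unpack4xU8((b >> 4) & 0x0F0F0F0Fu);  // weights 1, 3, 5, 7
  // The shader re-interleaves them so b0/b1 follow the original K order:
  //   b0 = (lower[0], upper[0], lower[1], upper[1])
  //   b1 = (lower[2], upper[2], lower[3], upper[3])
  std::printf("b0 = %u %u %u %u\n", lower[0], upper[0], lower[1], upper[1]);
  std::printf("b1 = %u %u %u %u\n", lower[2], upper[2], lower[3], upper[3]);
  return 0;
}

Running it prints b0 = 1 2 3 4 and b1 = 5 6 7 8, i.e. the interleave restores the original K order before the dot products against tile_A.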
