@@ -71,144 +71,27 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
7171 shader.AddInput (" zero_points" , ShaderUsage::UseUniform);
7272 }
7373 shader.AddOutput (" output" , ShaderUsage::UseElementTypeAlias);
74+
7475 const uint32_t components_a = a.NumComponents ();
7576 const uint32_t components_b = b.NumComponents () / 4 ; // b is stored as uint32 which includes 4 uint8.
7677 constexpr uint32_t tile_size_k_vec = 16 ;
77- uint32_t elements_in_value_b = components_b * (32 / nbits_);
78- uint32_t tile_k_size = tile_size_k_vec * elements_in_value_b;
79- const uint32_t a_length_per_tile = tile_k_size / components_a;
80-
81- shader.AdditionalImplementation () << " const a_length_per_tile = " << a_length_per_tile << " u;\n "
82- << " const tile_size_k_vec = " << tile_size_k_vec << " ;\n "
83- << " const tile_size_k = " << tile_k_size << " u;\n "
84- << " const tile_size = " << tile_size_ << " u;\n "
85- << " const elements_in_value_b = " << elements_in_value_b << " u;\n "
86- << " const sub_tile_count = " << WorkgroupSizeX () / tile_size_k_vec << " u;\n "
87- << " const component_a = " << components_a << " u;\n "
88- << " const component_b = " << components_b << " u;\n " ;
89- shader.AdditionalImplementation () << R"ADDNL_FN(
90- // Shared memory
91- var<workgroup> tile_A : array<input_a_value_t, a_length_per_tile>;
92- var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
93- fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32)
94- {
95- let k_offset = kidx / component_a + col;
96- if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) {
97- tile_A[col] = input_a[batch * uniforms.M * uniforms.K_of_a + a_global * uniforms.K_of_a + k_offset];
98- } else {
99- tile_A[col] = input_a_value_t(0);
100- }
101- }
102- )ADDNL_FN"
103- << GenerateZeroPointReadingCode (nbits_, has_zero_points_);
104-
105- shader.MainFunctionBody () << R"MAIN_FN(
106- let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile);
107- let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M;
108- let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
109-
110- let idx = local_idx % tile_size_k_vec;
111- let idy = local_idx / tile_size_k_vec;
112-
113- for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k)
114- {
115- for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x)
116- {
117- loadSHMA(batch, a_global, kidx, id);
118- }
119- workgroupBarrier();
120-
121- for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count)
122- {
123- var b_global = b_global_base + local_row_offset + idy;
124- var k_offset = kidx / elements_in_value_b + idx;
125- if (b_global < uniforms.N && k_offset < uniforms.K_of_b)
126- {
127- let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size;
128- let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx];
129- let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
130- var b_value = input_b[b_global * uniforms.K_of_b + k_offset];
131- )MAIN_FN" ;
132-
133- if (nbits_ == 4 ) {
134- shader.MainFunctionBody () << R"MAIN_FN(
135- var sum = output_element_t(0);
136- var a_offset = idx * (8 / component_a) * component_b;
137- for (var i = 0u; i < component_b; i++) {
138- let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
139- let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
140- let b0 = vec4<output_element_t>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b;
141- let b1 = vec4<output_element_t>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b;
142- )MAIN_FN" ;
143- switch (components_a) {
144- case 1 :
145- shader.MainFunctionBody () << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +"
146- " dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);\n "
147- " a_offset += 8;\n " ;
148- break ;
149- case 2 :
150- shader.MainFunctionBody () << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +"
151- " dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);\n "
152- " a_offset += 4;\n " ;
153- break ;
154- case 4 :
155- shader.MainFunctionBody () << " sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);\n "
156- " a_offset += 2;\n " ;
157- break ;
158- default :
159- break ;
160- }
161- shader.MainFunctionBody () << " }\n " ;
162- } else {
163- shader.MainFunctionBody () << R"MAIN_FN(
164- var sum = output_element_t(0);
165- var a_offset = idx * (4 / component_a) * component_b;
166- for (var i = 0u; i < component_b; i++) {
167- let b_value = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
168- )MAIN_FN" ;
169- switch (components_a) {
170- case 1 :
171- shader.MainFunctionBody () << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);\n "
172- " a_offset += 4;\n " ;
173- break ;
174- case 2 :
175- shader.MainFunctionBody () << " sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value);\n "
176- " a_offset += 2;\n " ;
177- break ;
178- case 4 :
179- shader.MainFunctionBody () << " sum += dot(tile_A[a_offset], b_value);\n "
180- " a_offset += 1;\n " ;
181- break ;
182- default :
183- break ;
184- }
185- shader.MainFunctionBody () << " }\n " ;
186- }
187-
188- shader.MainFunctionBody () << R"MAIN_FN(
189- inter_results[local_row_offset + idy][idx] += sum;
190- }
191- }
192- workgroupBarrier();
193- }
194-
195- if (batch >= uniforms.batch_count) {
196- return;
197- }
198-
199- if (local_idx < tile_size) {
200- var output_value = output_element_t(0);
201- for (var b = 0u; b < tile_size_k_vec; b++) {
202- output_value += inter_results[local_idx][b];
203- }
204- let b_global = b_global_base + local_idx;
205- let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global;
206- if (b_global < uniforms.N) {
207- output[output_idx] = output_value;
208- }
209- }
210- )MAIN_FN" ;
211-
78+ const uint32_t elements_in_value_b = components_b * (32 / nbits_);
79+ const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
80+ const uint32_t a_length_per_tile = tile_size_k / components_a;
81+ uint32_t sub_tile_count = WorkgroupSizeX () / tile_size_k_vec;
82+
83+ return WGSL_TEMPLATE_APPLY (shader, " quantization/matmul_nbits.wgsl.template" ,
84+ WGSL_TEMPLATE_PARAMETER (a_length_per_tile, a_length_per_tile),
85+ WGSL_TEMPLATE_PARAMETER (component_a, components_a),
86+ WGSL_TEMPLATE_PARAMETER (component_b, components_b),
87+ WGSL_TEMPLATE_PARAMETER (elements_in_value_b, elements_in_value_b),
88+ WGSL_TEMPLATE_PARAMETER (has_zero_points, has_zero_points_),
89+ WGSL_TEMPLATE_PARAMETER (n_bits, nbits_),
90+ WGSL_TEMPLATE_PARAMETER (output_type_i32, false ),
91+ WGSL_TEMPLATE_PARAMETER (sub_tile_count, sub_tile_count),
92+ WGSL_TEMPLATE_PARAMETER (tile_size, tile_size_),
93+ WGSL_TEMPLATE_PARAMETER (tile_size_k, tile_size_k),
94+ WGSL_TEMPLATE_PARAMETER (tile_size_k_vec, tile_size_k_vec));
21295 return Status::OK ();
21396}
21497
0 commit comments