44#param block_size
55#param n_bits
66#param has_zero_points
7+ #param is_qualcomm
78
89#include "quantization/dp4a_matmul_common.wgsl.template"
910
@@ -138,18 +139,35 @@ $MAIN {
138139
139140 // During the compute phase, we have the 64x64 tile split into
140141 // subtiles of 16x16. We have a grid of 4x4 subtiles.
141- let subtile_id = u32(local_idx / subtile_size);
142- let subtile_idx = u32(subtile_id / 4);
143- let subtile_idy = u32(subtile_id % 4);
144- let base_A = subtile_idx * 16;
145- let base_B = subtile_idy * 16;
142+ var subtile_id = u32(local_idx / subtile_size);
143+ var subtile_idx = u32(subtile_id / 4);
144+ var subtile_idy = u32(subtile_id % 4);
145+ var base_A = subtile_idx * 16;
146+ var base_B = subtile_idy * 16;
146147 // For each subtile we have 16 threads assigned.
147- let a_idx = u32(local_idx % subtile_size);
148+ var a_idx = u32(local_idx % subtile_size);
148149
150+ #if is_qualcomm
151+ // subtile_idx is always 0
152+ // subtile_idy is one of {0,1,2,3}
153+    // In the Qualcomm case the subtile is rectangular (64x16) and we have 4 such subtiles. This way we
154+    // don't need to increase the number of lane outputs each thread has to track: if we used a 64x64
155+    // subtile instead, each thread would need var lane_outputs: array<output_element_t, 64>;.
156+ if (sg_size == 64) {
157+ subtile_id = u32(local_idx / sg_size);
158+ subtile_idx = u32(subtile_id / 4);
159+ subtile_idy = u32(subtile_id % 4);
160+ base_A = subtile_idx * sg_size;
161+ base_B = subtile_idy * 16;
162+ a_idx = sg_id;
163+ }
164+ var lane_outputs: array<output_element_t, 16>;
165+ #else
149166 var lane_output1: vec4<output_element_t>;
150167 var lane_output2: vec4<output_element_t>;
151168 var lane_output3: vec4<output_element_t>;
152169 var lane_output4: vec4<output_element_t>;
170+ #endif
153171 // K's vectorization is 16 items per index. See input_a/input_b.
154172 // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is
155173 // k tile size is 32. In vectorized space that is 32/16 = 2.
@@ -173,6 +191,34 @@ $MAIN {
173191 var own_scale_a: output_element_t = scale_A[base_A + a_idx];
174192
175193#if has_zero_points && n_bits == 8
194+ #if is_qualcomm
195+ if (sg_size == 64)
196+ {
197+ var own_b0: vec4<u32>;
198+ var own_b1: vec4<u32>;
199+ var own_scale_b: output_element_t;
200+ var zero: i32;
201+ if (sg_id < 16)
202+ {
203+ own_b0 = tile_B[0][base_B + sg_id];
204+ own_b1 = tile_B[1][base_B + sg_id];
205+ own_scale_b = scale_B[base_B + sg_id];
206+ zero = zeroes[base_B + sg_id];
207+ }
208+ // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
209+ for (var i = 0u; i < 16u; i++)
210+ {
211+ lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a, subgroupShuffle(zero, i));
212+ }
213+ }
214+ else
215+ {
216+ for (var i = 0u; i < 16u; i++)
217+ {
218+ lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i], zeroes[base_B + i]);
219+ }
220+ }
221+ #else
176222 if (sg_size == 16)
177223 {
178224 var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
@@ -225,7 +271,34 @@ $MAIN {
225271 lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]);
226272 lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]);
227273 }
274+ #endif
228275#else
276+ #if is_qualcomm
277+ if (sg_size == 64)
278+ {
279+ var own_b0: vec4<u32>;
280+ var own_b1: vec4<u32>;
281+ var own_scale_b: output_element_t;
282+ if (sg_id < 16)
283+ {
284+ own_b0 = tile_B[0][base_B + sg_id];
285+ own_b1 = tile_B[1][base_B + sg_id];
286+ own_scale_b = scale_B[base_B + sg_id];
287+ }
288+ // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
289+ for (var i = 0u; i < 16u; i++)
290+ {
291+ lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a);
292+ }
293+ }
294+ else
295+ {
296+ for (var i = 0u; i < 16u; i++)
297+ {
298+ lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i]);
299+ }
300+ }
301+ #else
229302 if (sg_size == 16)
230303 {
231304 var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
@@ -277,6 +350,7 @@ $MAIN {
277350 lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
278351 lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
279352 }
353+ #endif
280354#endif
281355 workgroupBarrier();
282356 }
@@ -287,9 +361,16 @@ $MAIN {
287361 // This creates a shader requirement that uniforms.N % 16 == 0
288362 if (a_global < uniforms.M && b_global < uniforms.N)
289363 {
364+ #if is_qualcomm
365+ output[output_idx] = vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]);
366+ output[output_idx+1] = vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]);
367+ output[output_idx+2] = vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]);
368+ output[output_idx+3] = vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]);
369+ #else
290370 output[output_idx] = lane_output1;
291371 output[output_idx+1] = lane_output2;
292372 output[output_idx+2] = lane_output3;
293373 output[output_idx+3] = lane_output4;
374+ #endif
294375 }
295376} // MAIN
0 commit comments