|
1 | 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
2 | 2 | // Licensed under the MIT License. |
3 | 3 |
|
| 4 | +// Template for DP4A Matrix Multiply Quantization |
| 5 | +// Quantizes input matrix A for DP4A computation |
| 6 | +// This shader quantizes float values to 8-bit signed integers using pack4x8snorm |
| 7 | + |
| 8 | +var<workgroup> a_values : array<array<input_a_value_t, 32>, 2>; |
| 9 | +var<workgroup> max_values : array<input_a_value_t, 4>; |
| 10 | + |
| 11 | +fn readInput(offset: u32) -> input_a_value_t |
| 12 | +{ |
| 13 | + if (offset >= uniforms.output_size) { |
| 14 | + return input_a_value_t(0); |
| 15 | + } |
| 16 | + return input_a[offset]; |
| 17 | +} |
| 18 | + |
4 | 19 | $MAIN { |
5 | | - var local_a : array<vec4<input_a_element_t>, 32>; |
6 | | - var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0); |
7 | | - for (var idx:u32=0;idx<32;idx+=1) |
| 20 | + if (sg_size == 32) { |
| 21 | + let local_a = readInput(global_idx); |
| 22 | + let max_val = subgroupMax(abs(local_a)); |
| 23 | + if (global_idx >= uniforms.output_size) { |
| 24 | + return; |
| 25 | + } |
| 26 | + let max_temp = max(max_val.xy, max_val.zw); |
| 27 | + let scale = max(max_temp[0], max_temp[1]); |
| 28 | + let norm_a = local_a/scale; |
| 29 | + output[global_idx] = pack4x8snorm(vec4<f32>(norm_a)); |
| 30 | + if (local_idx % 32 == 0) |
| 31 | + { |
| 32 | + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. |
| 33 | + scales[workgroup_idx * 2 + local_idx / 32] = scale/127; |
| 34 | + } |
| 35 | + } else if (sg_size == 16) { |
| 36 | + let local_a = readInput(global_idx); |
| 37 | + let sub_max_value = subgroupMax(abs(local_a)); |
| 38 | + if (local_idx % 16 == 0) { |
| 39 | + max_values[local_idx / 16] = sub_max_value; |
| 40 | + } |
| 41 | + workgroupBarrier(); |
| 42 | + |
| 43 | + if (global_idx >= uniforms.output_size) { |
| 44 | + return; |
| 45 | + } |
| 46 | + |
| 47 | + var max_val = input_a_value_t(0); |
| 48 | + if (local_idx < 32) { |
| 49 | + max_val = max(max_values[0], max_values[1]); |
| 50 | + } else { |
| 51 | + max_val = max(max_values[2], max_values[3]); |
| 52 | + } |
| 53 | + let max_temp = max(max_val.xy, max_val.zw); |
| 54 | + let scale = max(max_temp[0], max_temp[1]); |
| 55 | + let norm_a = local_a/scale; |
| 56 | + output[global_idx] = pack4x8snorm(vec4<f32>(norm_a)); |
| 57 | + if (local_idx % 32 == 0) |
| 58 | + { |
| 59 | + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. |
| 60 | + scales[workgroup_idx * 2 + local_idx / 32] = scale/127; |
| 61 | + } |
| 62 | + } else { |
| 63 | + let local_row = local_idx / 32u; |
| 64 | + let local_col = local_idx % 32u; |
| 65 | + a_values[local_row][local_col] = readInput(global_idx); |
| 66 | + workgroupBarrier(); |
| 67 | + |
| 68 | + if (global_idx >= uniforms.output_size) { |
| 69 | + return; |
| 70 | + } |
| 71 | + |
| 72 | + var max_val = input_a_value_t(0); |
| 73 | + // TODO: Optimize this part so that all the threads are not computing the same value. |
| 74 | + for (var i = 0u; i < 32u; i++) |
8 | 75 | { |
9 | | - local_a[idx] = input_a[workgroup_idx*32 + idx]; |
10 | | - max_value = max(max_value, abs(local_a[idx])); |
| 76 | + max_val = max(max_val, abs(a_values[local_row][i])); |
11 | 77 | } |
12 | | - var scale = max(max_value.x, max_value.y); |
13 | | - scale = max(scale, max_value.z); |
14 | | - scale = max(scale, max_value.w); |
15 | | - for (var idx:u32=0;idx<32;idx+=1) |
| 78 | + let max_temp = max(max_val.xy, max_val.zw); |
| 79 | + let scale = max(max_temp[0], max_temp[1]); |
| 80 | + let norm_a = a_values[local_row][local_col]/scale; |
| 81 | + output[global_idx] = pack4x8snorm(vec4<f32>(norm_a)); |
| 82 | + if (local_col == 0u) |
16 | 83 | { |
17 | | - output[workgroup_idx*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale)); |
| 84 | + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. |
| 85 | + scales[workgroup_idx * 2 + local_row] = scale/127; |
18 | 86 | } |
19 | | - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. |
20 | | - scales[workgroup_idx] = scale/127; |
21 | | -} // MAIN |
| 87 | + } |
| 88 | +} |
0 commit comments