|
2 | 2 |
|
3 | 3 | import uk.ac.manchester.tornado.api.KernelContext; |
4 | 4 | import uk.ac.manchester.tornado.api.math.TornadoMath; |
5 | | -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; |
6 | 5 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
7 | 6 |
|
8 | 7 | public class TransformerComputeKernels { |
@@ -95,310 +94,4 @@ public static void reductionOneBlock2WithLogits(KernelContext context, FloatArra |
95 | 94 | output.set(gid, weights.get(gid) * (ss * output.get(gid))); |
96 | 95 | } |
97 | 96 |
|
98 | | - /** |
99 | | - * Optimized matrix-vector multiplication for Q8_0 quantized weights. |
100 | | - * Q8_0 format: 8-bit quantization with FP16 scale per 32-element block. |
101 | | - * |
102 | | - * Block format: |
103 | | - * - Bytes 0-1: Float16 scale value (little-endian) |
104 | | - * - Bytes 2-33: 32 quantized values (int8) |
105 | | - * |
106 | | - * Optimization features: |
107 | | - * - Loop unrolling (16x factor) |
108 | | - * - Vectorization (4 elements per iteration) |
109 | | - * - Scale caching to avoid redundant decompression |
110 | | - * - Fused multiply-add for accumulation |
111 | | - * |
112 | | - * @param context Kernel execution context |
113 | | - * @param thisx Quantized matrix in row-major order |
114 | | - * @param that Vector to multiply |
115 | | - * @param out Output vector (result) |
116 | | - * @param dim1 Vector dimension |
117 | | - */ |
118 | | - public static void matmulTornadoQ8Optimized(KernelContext context, ByteArray thisx, FloatArray that, FloatArray out, int dim1) { |
119 | | - final int BLOCK_SIZE = 32; // Block size used in quantization |
120 | | - final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE; // 2 bytes for scale + block_size bytes for values |
121 | | - final int UNROLL_FACTOR = 16; // Increased unroll factor for better performance |
122 | | - final int VECTOR_SIZE = 4; // Process 4 elements at once with vectorization |
123 | | - |
124 | | - int idx = context.globalIdx; |
125 | | - float result = 0f; |
126 | | - int thisOffset = idx * dim1; |
127 | | - |
128 | | - // Cache last block index and scale to avoid redundant decoding |
129 | | - int lastBlockIndex = -1; |
130 | | - float cachedScale = 0f; |
131 | | - |
132 | | - // Pre-compute unroll-loop boundaries to reduce in-loop calculations |
133 | | - int numFullUnrolls = dim1 / UNROLL_FACTOR; |
134 | | - int remainingStart = numFullUnrolls * UNROLL_FACTOR; |
135 | | - |
136 | | - // Pre-calculate first block index to potentially save work in the loop |
137 | | - int firstIndex = thisOffset; |
138 | | - int firstBlockIndex = firstIndex / BLOCK_SIZE; |
139 | | - int firstBlockOffset = firstBlockIndex * BYTES_PER_BLOCK; |
140 | | - |
141 | | - // Initial scale calculation outside the loop |
142 | | - int scaleByte1 = thisx.get(firstBlockOffset) & 0xFF; |
143 | | - int scaleByte2 = thisx.get(firstBlockOffset + 1) & 0xFF; |
144 | | - short scaleFloat16 = (short) ((scaleByte2 << 8) | scaleByte1); |
145 | | - cachedScale = decodeFloat16Fast(scaleFloat16); |
146 | | - lastBlockIndex = firstBlockIndex; |
147 | | - |
148 | | - // Main loop with increased unrolling |
149 | | - for (int j = 0; j < numFullUnrolls; j++) { |
150 | | - int baseIdx = j * UNROLL_FACTOR; |
151 | | - |
152 | | - // Process elements in groups of UNROLL_FACTOR |
153 | | - for (int k = 0; k < UNROLL_FACTOR; k += VECTOR_SIZE) { |
154 | | - // Process VECTOR_SIZE elements in each iteration |
155 | | - for (int v = 0; v < VECTOR_SIZE; v++) { |
156 | | - int index = thisOffset + baseIdx + k + v; |
157 | | - int blockIndex = index / BLOCK_SIZE; |
158 | | - |
159 | | - // Only decode scale if we're in a new block |
160 | | - if (blockIndex != lastBlockIndex) { |
161 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
162 | | - int newScaleByte1 = thisx.get(blockOffset) & 0xFF; |
163 | | - int newScaleByte2 = thisx.get(blockOffset + 1) & 0xFF; |
164 | | - short newScaleFloat16 = (short) ((newScaleByte2 << 8) | newScaleByte1); |
165 | | - cachedScale = decodeFloat16Fast(newScaleFloat16); |
166 | | - lastBlockIndex = blockIndex; |
167 | | - } |
168 | | - |
169 | | - int withinBlockIndex = index % BLOCK_SIZE; |
170 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
171 | | - |
172 | | - // Read quantized value |
173 | | - byte quantized = thisx.get(blockOffset + 2 + withinBlockIndex); |
174 | | - |
175 | | - // Dequantize and accumulate |
176 | | - result = fma(quantized * cachedScale, that.get(baseIdx + k + v), result); |
177 | | - } |
178 | | - } |
179 | | - } |
180 | | - |
181 | | - // Handle remaining elements |
182 | | - for (int j = remainingStart; j < dim1; j++) { |
183 | | - int index = thisOffset + j; |
184 | | - int blockIndex = index / BLOCK_SIZE; |
185 | | - |
186 | | - // Only decode scale if we're in a new block |
187 | | - if (blockIndex != lastBlockIndex) { |
188 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
189 | | - int scaleByte11 = thisx.get(blockOffset) & 0xFF; |
190 | | - int scaleByte22 = thisx.get(blockOffset + 1) & 0xFF; |
191 | | - short scaleFloat166 = (short) ((scaleByte22 << 8) | scaleByte11); |
192 | | - cachedScale = decodeFloat16Fast(scaleFloat166); |
193 | | - lastBlockIndex = blockIndex; |
194 | | - } |
195 | | - |
196 | | - int withinBlockIndex = index % BLOCK_SIZE; |
197 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
198 | | - |
199 | | - // Read quantized value |
200 | | - byte quantized = thisx.get(blockOffset + 2 + withinBlockIndex); |
201 | | - |
202 | | - // Dequantize and accumulate |
203 | | - result = fma(quantized * cachedScale, that.get(j), result); |
204 | | - } |
205 | | - |
206 | | - out.set(idx, result); |
207 | | - } |
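For readers unfamiliar with the Q8_0 layout assumed above, the same block addressing can be sketched in plain Java for a single element. This is an illustrative helper, not part of this class; `Float.float16ToFloat` (available from Java 20 onward) stands in for `decodeFloat16Fast`:

```java
// Illustrative Q8_0 scalar dequantization using the layout assumed by
// matmulTornadoQ8Optimized: each 34-byte block holds a little-endian FP16
// scale followed by 32 int8 values.
static float dequantizeQ8(byte[] data, int index) {
    final int BLOCK_SIZE = 32;
    final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE;

    int blockIndex = index / BLOCK_SIZE;            // which 32-element block
    int blockOffset = blockIndex * BYTES_PER_BLOCK; // byte offset of that block
    int withinBlock = index % BLOCK_SIZE;           // position inside the block

    int lo = data[blockOffset] & 0xFF;              // low byte of the FP16 scale
    int hi = data[blockOffset + 1] & 0xFF;          // high byte of the FP16 scale
    float scale = Float.float16ToFloat((short) ((hi << 8) | lo)); // Java 20+

    byte q = data[blockOffset + 2 + withinBlock];   // quantized int8 value
    return q * scale;
}
```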
208 | | - |
209 | | - /** |
210 | | - * Optimized matrix-vector multiplication for Q4_0 quantized weights. |
211 | | - * Q4_0 format: 4-bit quantization with FP16 scale per 32-element block. |
212 | | - * |
213 | | - * Block format: |
214 | | - * - Bytes 0-1: Float16 scale value (little-endian) |
215 | | - * - Bytes 2-17: 16 bytes storing 32 packed 4-bit values |
216 | | - * |
217 | | - * Each byte stores two 4-bit values: |
218 | | - * - Lower nibble: elements 0-15 of the block |
219 | | - * - Upper nibble: elements 16-31 of the block |
220 | | - * |
221 | | - * Q4_0 stores unsigned nibbles (0-15) interpreted with a -8 offset (effective range: -8 to 7) |
222 | | - * |
223 | | - * @param context Kernel execution context |
224 | | - * @param thisx Quantized matrix in row-major order |
225 | | - * @param that Vector to multiply |
226 | | - * @param out Output vector (result) |
227 | | - * @param dim1 Vector dimension |
228 | | - */ |
229 | | - public static void matmulTornadoQ4Optimized(KernelContext context, ByteArray thisx, FloatArray that, FloatArray out, int dim1) { |
230 | | - final int BLOCK_SIZE = 32; // Block size for Q4_0 |
231 | | - final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE / 2; // 2 bytes for scale + 16 bytes for packed values |
232 | | - final int UNROLL_FACTOR = 16; // Unroll factor for better performance |
233 | | - final int VECTOR_SIZE = 4; // Process 4 elements at once with vectorization |
234 | | - |
235 | | - int idx = context.globalIdx; |
236 | | - float result = 0f; |
237 | | - int thisOffset = idx * dim1; |
238 | | - |
239 | | - // Cache last block index and scale to avoid redundant decoding |
240 | | - int lastBlockIndex = -1; |
241 | | - float cachedScale = 0f; |
242 | | - |
243 | | - // Pre-compute unroll-loop boundaries to reduce in-loop calculations |
244 | | - int numFullUnrolls = dim1 / UNROLL_FACTOR; |
245 | | - int remainingStart = numFullUnrolls * UNROLL_FACTOR; |
246 | | - |
247 | | - // Pre-calculate first block index to potentially save work in the loop |
248 | | - int firstIndex = thisOffset; |
249 | | - int firstBlockIndex = firstIndex / BLOCK_SIZE; |
250 | | - int firstBlockOffset = firstBlockIndex * BYTES_PER_BLOCK; |
251 | | - |
252 | | - // Initial scale calculation outside the loop |
253 | | - int scaleByte1 = thisx.get(firstBlockOffset) & 0xFF; |
254 | | - int scaleByte2 = thisx.get(firstBlockOffset + 1) & 0xFF; |
255 | | - short scaleFloat16 = (short) ((scaleByte2 << 8) | scaleByte1); |
256 | | - cachedScale = decodeFloat16Fast(scaleFloat16); |
257 | | - lastBlockIndex = firstBlockIndex; |
258 | | - |
259 | | - // Main loop with increased unrolling |
260 | | - for (int j = 0; j < numFullUnrolls; j++) { |
261 | | - int baseIdx = j * UNROLL_FACTOR; |
262 | | - |
263 | | - // Process elements in groups of UNROLL_FACTOR |
264 | | - for (int k = 0; k < UNROLL_FACTOR; k += VECTOR_SIZE) { |
265 | | - // Process VECTOR_SIZE elements in each iteration |
266 | | - for (int v = 0; v < VECTOR_SIZE; v++) { |
267 | | - int index = thisOffset + baseIdx + k + v; |
268 | | - int blockIndex = index / BLOCK_SIZE; |
269 | | - |
270 | | - // Only decode scale if we're in a new block |
271 | | - if (blockIndex != lastBlockIndex) { |
272 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
273 | | - int newScaleByte1 = thisx.get(blockOffset) & 0xFF; |
274 | | - int newScaleByte2 = thisx.get(blockOffset + 1) & 0xFF; |
275 | | - short newScaleFloat16 = (short) ((newScaleByte2 << 8) | newScaleByte1); |
276 | | - cachedScale = decodeFloat16Fast(newScaleFloat16); |
277 | | - lastBlockIndex = blockIndex; |
278 | | - } |
279 | | - |
280 | | - int withinBlockIndex = index % BLOCK_SIZE; |
281 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
282 | | - |
283 | | - // Extract Q4 value from packed byte |
284 | | - byte quant; |
285 | | - if (withinBlockIndex < BLOCK_SIZE / 2) { |
286 | | - // Lower nibble |
287 | | - quant = (byte) (thisx.get(blockOffset + 2 + withinBlockIndex) & 0x0F); |
288 | | - } else { |
289 | | - // Upper nibble |
290 | | - quant = (byte) ((thisx.get(blockOffset + 2 + withinBlockIndex - BLOCK_SIZE / 2) >>> 4) & 0x0F); |
291 | | - } |
292 | | - |
293 | | - // Apply Q4 offset and scale |
294 | | - quant -= 8; // Q4 uses -8 offset |
295 | | - |
296 | | - // Dequantize and accumulate |
297 | | - result = fma(quant * cachedScale, that.get(baseIdx + k + v), result); |
298 | | - } |
299 | | - } |
300 | | - } |
301 | | - |
302 | | - // Handle remaining elements |
303 | | - for (int j = remainingStart; j < dim1; j++) { |
304 | | - int index = thisOffset + j; |
305 | | - int blockIndex = index / BLOCK_SIZE; |
306 | | - |
307 | | - // Only decode scale if we're in a new block |
308 | | - if (blockIndex != lastBlockIndex) { |
309 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
310 | | - int scaleByte11 = thisx.get(blockOffset) & 0xFF; |
311 | | - int scaleByte22 = thisx.get(blockOffset + 1) & 0xFF; |
312 | | - short scaleFloat166 = (short) ((scaleByte22 << 8) | scaleByte11); |
313 | | - cachedScale = decodeFloat16Fast(scaleFloat166); |
314 | | - lastBlockIndex = blockIndex; |
315 | | - } |
316 | | - |
317 | | - int withinBlockIndex = index % BLOCK_SIZE; |
318 | | - int blockOffset = blockIndex * BYTES_PER_BLOCK; |
319 | | - |
320 | | - // Extract Q4 value from packed byte |
321 | | - byte quant; |
322 | | - if (withinBlockIndex < BLOCK_SIZE / 2) { |
323 | | - // Lower nibble |
324 | | - quant = (byte) (thisx.get(blockOffset + 2 + withinBlockIndex) & 0x0F); |
325 | | - } else { |
326 | | - // Upper nibble |
327 | | - quant = (byte) ((thisx.get(blockOffset + 2 + withinBlockIndex - BLOCK_SIZE / 2) >>> 4) & 0x0F); |
328 | | - } |
329 | | - |
330 | | - // Apply Q4 offset and scale |
331 | | - quant -= 8; // Q4 uses -8 offset |
332 | | - |
333 | | - // Dequantize and accumulate |
334 | | - result = fma(quant * cachedScale, that.get(j), result); |
335 | | - } |
336 | | - |
337 | | - out.set(idx, result); |
338 | | - } |
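Similarly, the nibble convention used by the Q4_0 kernel (lower nibbles hold the first half of the block, upper nibbles the second half) can be illustrated with a small sketch; `unpackQ4` is a hypothetical helper, not code from this repository:

```java
// Illustrative Q4_0 unpacking for one element of an 18-byte block
// (2-byte FP16 scale + 16 bytes holding 32 packed 4-bit values).
static int unpackQ4(byte[] block, int withinBlockIndex) {
    final int HALF = 16; // BLOCK_SIZE / 2 packed bytes follow the scale
    int nibble;
    if (withinBlockIndex < HALF) {
        // Elements 0-15 sit in the lower nibbles of bytes 2..17.
        nibble = block[2 + withinBlockIndex] & 0x0F;
    } else {
        // Elements 16-31 sit in the upper nibbles of the same bytes.
        nibble = (block[2 + withinBlockIndex - HALF] >> 4) & 0x0F;
    }
    return nibble - 8; // undo the Q4_0 offset: stored 0..15 maps to -8..7
}
```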
339 | | - |
340 | | - /** |
341 | | - * Fast decoder for IEEE 754 half-precision (Float16) floating point format. |
342 | | - * Converts 16-bit encoded values to 32-bit float. |
343 | | - * Float16 format: |
344 | | - * - Bit 15: Sign bit |
345 | | - * - Bits 14-10: Exponent (5 bits) |
346 | | - * - Bits 9-0: Mantissa/Fraction (10 bits) |
347 | | - * Special cases: |
348 | | - * - Exponent all 1s: Infinity (mantissa 0) or NaN (mantissa non-zero) |
349 | | - * - Exponent all 0s: Zero (mantissa 0) or denormalized (mantissa non-zero) |
350 | | - * |
351 | | - * @param value 16-bit encoded Float16 value |
352 | | - * @return Decoded 32-bit float value |
353 | | - */ |
354 | | - private static float decodeFloat16Fast(short value) { |
355 | | - // Split the components |
356 | | - int sign = (value & 0x8000) >>> 15; |
357 | | - int exp = (value & 0x7C00) >>> 10; |
358 | | - int frac = value & 0x03FF; |
359 | | - |
360 | | - // Handle special cases with direct returns for common values |
361 | | - if (exp == 0x1F) { |
362 | | - return frac == 0 ? (sign == 0 ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY) : Float.NaN; |
363 | | - } |
364 | | - |
365 | | - if (exp == 0) { |
366 | | - if (frac == 0) { |
367 | | - return sign == 0 ? 0.0f : -0.0f; |
368 | | - } |
369 | | - // Optimize denormalized numbers with precomputed constant |
370 | | - float result = frac * 5.9604645E-8f; // Precomputed 2^-24 |
371 | | - return sign == 0 ? result : -result; |
372 | | - } |
373 | | - |
374 | | - // Normal case - optimize with fewer operations |
375 | | - float result = 1.0f + (frac / 1024.0f); |
376 | | - |
377 | | - // Use bitshift instead of pow for integer powers of 2 |
378 | | - if (exp < 15) { |
379 | | - int shift = 15 - exp; |
380 | | - result /= (1 << shift); |
381 | | - } else { |
382 | | - int shift = exp - 15; |
383 | | - result *= (1 << shift); |
384 | | - } |
385 | | - |
386 | | - return sign == 0 ? result : -result; |
387 | | - } |
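A few known FP16 bit patterns make the decoder easy to sanity-check from within the same class (values follow directly from the IEEE 754 half-precision encoding):

```java
// 0x3C00: sign 0, exponent 15, fraction 0  -> 1.0f
// 0xC000: sign 1, exponent 16, fraction 0  -> -2.0f
// 0x0001: smallest subnormal               -> 2^-24 = 5.9604645E-8f
float one    = decodeFloat16Fast((short) 0x3C00);
float negTwo = decodeFloat16Fast((short) 0xC000);
float tiny   = decodeFloat16Fast((short) 0x0001);
```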
388 | | - |
389 | | - |
390 | | - /** |
391 | | - * Fused multiply-add operation: a * b + c |
392 | | - * This is typically implemented as a single instruction on modern hardware, |
393 | | - * providing better performance and numeric precision than separate operations. |
394 | | - * |
395 | | - * @param a First factor |
396 | | - * @param b Second factor |
397 | | - * @param c Value to add |
398 | | - * @return Result of a * b + c |
399 | | - */ |
400 | | - private static float fma(float a, float b, float c) { |
401 | | - return a * b + c; |
402 | | - } |
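On the host JVM a true fused multiply-add is available as `Math.fma` (Java 9 and later); the helper above is a plain multiply-add, presumably kept simple so the kernel code generator can lower it as it sees fit. The difference is in rounding behavior:

```java
// Math.fma rounds once after the fused multiply-add; a * b + c rounds after
// the multiply and again after the add. For most dot-product accumulations
// the results are close, but they are not bit-identical in general.
float a = 1e-3f, b = 3.0f, c = -0.5f;
float fused = Math.fma(a, b, c);
float plain = a * b + c;
```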
403 | | - |
404 | 97 | } |