
Commit 6d18d95

Remove redundant methods and cleanup unused imports
1 parent 45bfbbe commit 6d18d95

2 files changed: +14, -332 lines

src/main/java/com/example/tornadovm/TransformerComputeKernels.java

Lines changed: 0 additions & 307 deletions
@@ -2,7 +2,6 @@

 import uk.ac.manchester.tornado.api.KernelContext;
 import uk.ac.manchester.tornado.api.math.TornadoMath;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

 public class TransformerComputeKernels {
@@ -95,310 +94,4 @@ public static void reductionOneBlock2WithLogits(KernelContext context, FloatArra
         output.set(gid, weights.get(gid) * (ss * output.get(gid)));
     }

-    /**
-     * Optimized matrix-vector multiplication for Q8_0 quantized weights.
-     * Q8_0 format: 8-bit quantization with FP16 scale per 32-element block.
-     *
-     * Block format:
-     * - Bytes 0-1: Float16 scale value (big-endian)
-     * - Bytes 2-33: 32 quantized values (int8)
-     *
-     * Optimization features:
-     * - Loop unrolling (16x factor)
-     * - Vectorization (4 elements per iteration)
-     * - Scale caching to avoid redundant decompression
-     * - Fused multiply-add for accumulation
-     *
-     * @param context Kernel execution context
-     * @param thisx Quantized matrix in row-major order
-     * @param that Vector to multiply
-     * @param out Output vector (result)
-     * @param dim1 Vector dimension
-     */
-    public static void matmulTornadoQ8Optimized(KernelContext context, ByteArray thisx, FloatArray that, FloatArray out, int dim1) {
-        final int BLOCK_SIZE = 32; // Block size used in quantization
-        final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE; // 2 bytes for scale + block_size bytes for values
-        final int UNROLL_FACTOR = 16; // Increased unroll factor for better performance
-        final int VECTOR_SIZE = 4; // Process 4 elements at once with vectorization
-
-        int idx = context.globalIdx;
-        float result = 0f;
-        int thisOffset = idx * dim1;
-
-        // Cache last block index and scale to avoid redundant decoding
-        int lastBlockIndex = -1;
-        float cachedScale = 0f;
-
-        // Early calculation of block boundaries to reduce in-loop calculations
-        int numFullUnrolls = dim1 / UNROLL_FACTOR;
-        int remainingStart = numFullUnrolls * UNROLL_FACTOR;
-
-        // Pre-calculate first block index to potentially save work in the loop
-        int firstIndex = thisOffset;
-        int firstBlockIndex = firstIndex / BLOCK_SIZE;
-        int firstBlockOffset = firstBlockIndex * BYTES_PER_BLOCK;
-
-        // Initial scale calculation outside the loop
-        int scaleByte1 = thisx.get(firstBlockOffset) & 0xFF;
-        int scaleByte2 = thisx.get(firstBlockOffset + 1) & 0xFF;
-        short scaleFloat16 = (short) ((scaleByte2 << 8) | scaleByte1);
-        cachedScale = decodeFloat16Fast(scaleFloat16);
-        lastBlockIndex = firstBlockIndex;
-
-        // Main loop with increased unrolling
-        for (int j = 0; j < numFullUnrolls; j++) {
-            int baseIdx = j * UNROLL_FACTOR;
-
-            // Process elements in groups of UNROLL_FACTOR
-            for (int k = 0; k < UNROLL_FACTOR; k += VECTOR_SIZE) {
-                // Process VECTOR_SIZE elements in each iteration
-                for (int v = 0; v < VECTOR_SIZE; v++) {
-                    int index = thisOffset + baseIdx + k + v;
-                    int blockIndex = index / BLOCK_SIZE;
-
-                    // Only decode scale if we're in a new block
-                    if (blockIndex != lastBlockIndex) {
-                        int blockOffset = blockIndex * BYTES_PER_BLOCK;
-                        int newScaleByte1 = thisx.get(blockOffset) & 0xFF;
-                        int newScaleByte2 = thisx.get(blockOffset + 1) & 0xFF;
-                        short newScaleFloat16 = (short) ((newScaleByte2 << 8) | newScaleByte1);
-                        cachedScale = decodeFloat16Fast(newScaleFloat16);
-                        lastBlockIndex = blockIndex;
-                    }
-
-                    int withinBlockIndex = index % BLOCK_SIZE;
-                    int blockOffset = blockIndex * BYTES_PER_BLOCK;
-
-                    // Read quantized value
-                    byte quantized = thisx.get(blockOffset + 2 + withinBlockIndex);
-
-                    // Dequantize and accumulate
-                    result = fma(quantized * cachedScale, that.get(baseIdx + k + v), result);
-                }
-            }
-        }
-
-        // Handle remaining elements
-        for (int j = remainingStart; j < dim1; j++) {
-            int index = thisOffset + j;
-            int blockIndex = index / BLOCK_SIZE;
-
-            // Only decode scale if we're in a new block
-            if (blockIndex != lastBlockIndex) {
-                int blockOffset = blockIndex * BYTES_PER_BLOCK;
-                int scaleByte11 = thisx.get(blockOffset) & 0xFF;
-                int scaleByte22 = thisx.get(blockOffset + 1) & 0xFF;
-                short scaleFloat166 = (short) ((scaleByte22 << 8) | scaleByte11);
-                cachedScale = decodeFloat16Fast(scaleFloat166);
-                lastBlockIndex = blockIndex;
-            }
-
-            int withinBlockIndex = index % BLOCK_SIZE;
-            int blockOffset = blockIndex * BYTES_PER_BLOCK;
-
-            // Read quantized value
-            byte quantized = thisx.get(blockOffset + 2 + withinBlockIndex);
-
-            // Dequantize and accumulate
-            result = fma(quantized * cachedScale, that.get(j), result);
-        }
-
-        out.set(idx, result);
-    }
-
-    /**
-     * Optimized matrix-vector multiplication for Q4_0 quantized weights.
-     * Q4_0 format: 4-bit quantization with FP16 scale per 32-element block.
-     *
-     * Block format:
-     * - Bytes 0-1: Float16 scale value (big-endian)
-     * - Bytes 2-17: 16 bytes storing 32 packed 4-bit values
-     *
-     * Each byte stores two 4-bit values:
-     * - Lower nibble: elements 0, 2, 4, ...
-     * - Upper nibble: elements 1, 3, 5, ...
-     *
-     * Q4_0 uses signed 4-bit values with -8 offset (range: -8 to 7)
-     *
-     * @param context Kernel execution context
-     * @param thisx Quantized matrix in row-major order
-     * @param that Vector to multiply
-     * @param out Output vector (result)
-     * @param dim1 Vector dimension
-     */
-    public static void matmulTornadoQ4Optimized(KernelContext context, ByteArray thisx, FloatArray that, FloatArray out, int dim1) {
-        final int BLOCK_SIZE = 32; // Block size for Q4_0
-        final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE / 2; // 2 bytes for scale + 16 bytes for packed values
-        final int UNROLL_FACTOR = 16; // Unroll factor for better performance
-        final int VECTOR_SIZE = 4; // Process 4 elements at once with vectorization
-
-        int idx = context.globalIdx;
-        float result = 0f;
-        int thisOffset = idx * dim1;
-
-        // Cache last block index and scale to avoid redundant decoding
-        int lastBlockIndex = -1;
-        float cachedScale = 0f;
-
-        // Early calculation of block boundaries to reduce in-loop calculations
-        int numFullUnrolls = dim1 / UNROLL_FACTOR;
-        int remainingStart = numFullUnrolls * UNROLL_FACTOR;
-
-        // Pre-calculate first block index to potentially save work in the loop
-        int firstIndex = thisOffset;
-        int firstBlockIndex = firstIndex / BLOCK_SIZE;
-        int firstBlockOffset = firstBlockIndex * BYTES_PER_BLOCK;
-
-        // Initial scale calculation outside the loop
-        int scaleByte1 = thisx.get(firstBlockOffset) & 0xFF;
-        int scaleByte2 = thisx.get(firstBlockOffset + 1) & 0xFF;
-        short scaleFloat16 = (short) ((scaleByte2 << 8) | scaleByte1);
-        cachedScale = decodeFloat16Fast(scaleFloat16);
-        lastBlockIndex = firstBlockIndex;
-
-        // Main loop with increased unrolling
-        for (int j = 0; j < numFullUnrolls; j++) {
-            int baseIdx = j * UNROLL_FACTOR;
-
-            // Process elements in groups of UNROLL_FACTOR
-            for (int k = 0; k < UNROLL_FACTOR; k += VECTOR_SIZE) {
-                // Process VECTOR_SIZE elements in each iteration
-                for (int v = 0; v < VECTOR_SIZE; v++) {
-                    int index = thisOffset + baseIdx + k + v;
-                    int blockIndex = index / BLOCK_SIZE;
-
-                    // Only decode scale if we're in a new block
-                    if (blockIndex != lastBlockIndex) {
-                        int blockOffset = blockIndex * BYTES_PER_BLOCK;
-                        int newScaleByte1 = thisx.get(blockOffset) & 0xFF;
-                        int newScaleByte2 = thisx.get(blockOffset + 1) & 0xFF;
-                        short newScaleFloat16 = (short) ((newScaleByte2 << 8) | newScaleByte1);
-                        cachedScale = decodeFloat16Fast(newScaleFloat16);
-                        lastBlockIndex = blockIndex;
-                    }
-
-                    int withinBlockIndex = index % BLOCK_SIZE;
-                    int blockOffset = blockIndex * BYTES_PER_BLOCK;
-
-                    // Extract Q4 value from packed byte
-                    byte quant;
-                    if (withinBlockIndex < BLOCK_SIZE / 2) {
-                        // Lower nibble
-                        quant = (byte) (thisx.get(blockOffset + 2 + withinBlockIndex) & 0x0F);
-                    } else {
-                        // Upper nibble
-                        quant = (byte) ((thisx.get(blockOffset + 2 + withinBlockIndex - BLOCK_SIZE / 2) >>> 4) & 0x0F);
-                    }
-
-                    // Apply Q4 offset and scale
-                    quant -= 8; // Q4 uses -8 offset
-
-                    // Dequantize and accumulate
-                    result = fma(quant * cachedScale, that.get(baseIdx + k + v), result);
-                }
-            }
-        }
-
-        // Handle remaining elements
-        for (int j = remainingStart; j < dim1; j++) {
-            int index = thisOffset + j;
-            int blockIndex = index / BLOCK_SIZE;
-
-            // Only decode scale if we're in a new block
-            if (blockIndex != lastBlockIndex) {
-                int blockOffset = blockIndex * BYTES_PER_BLOCK;
-                int scaleByte11 = thisx.get(blockOffset) & 0xFF;
-                int scaleByte22 = thisx.get(blockOffset + 1) & 0xFF;
-                short scaleFloat166 = (short) ((scaleByte22 << 8) | scaleByte11);
-                cachedScale = decodeFloat16Fast(scaleFloat166);
-                lastBlockIndex = blockIndex;
-            }
-
-            int withinBlockIndex = index % BLOCK_SIZE;
-            int blockOffset = blockIndex * BYTES_PER_BLOCK;
-
-            // Extract Q4 value from packed byte
-            byte quant;
-            if (withinBlockIndex < BLOCK_SIZE / 2) {
-                // Lower nibble
-                quant = (byte) (thisx.get(blockOffset + 2 + withinBlockIndex) & 0x0F);
-            } else {
-                // Upper nibble
-                quant = (byte) ((thisx.get(blockOffset + 2 + withinBlockIndex - BLOCK_SIZE / 2) >>> 4) & 0x0F);
-            }
-
-            // Apply Q4 offset and scale
-            quant -= 8; // Q4 uses -8 offset
-
-            // Dequantize and accumulate
-            result = fma(quant * cachedScale, that.get(j), result);
-        }
-
-        out.set(idx, result);
-    }
-
-    /**
-     * Fast decoder for IEEE 754 half-precision (Float16) floating point format.
-     * Converts 16-bit encoded values to 32-bit float.
-     * Float16 format:
-     * - Bit 15: Sign bit
-     * - Bits 14-10: Exponent (5 bits)
-     * - Bits 9-0: Mantissa/Fraction (10 bits)
-     * Special cases:
-     * - Exponent all 1s: Infinity (mantissa 0) or NaN (mantissa non-zero)
-     * - Exponent all 0s: Zero (mantissa 0) or denormalized (mantissa non-zero)
-     *
-     * @param value 16-bit encoded Float16 value
-     * @return Decoded 32-bit float value
-     */
-    private static float decodeFloat16Fast(short value) {
-        // Split the components
-        int sign = (value & 0x8000) >>> 15;
-        int exp = (value & 0x7C00) >>> 10;
-        int frac = value & 0x03FF;
-
-        // Handle special cases with direct returns for common values
-        if (exp == 0x1F) {
-            return sign == 0 ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY;
-        }
-
-        if (exp == 0) {
-            if (frac == 0) {
-                return sign == 0 ? 0.0f : -0.0f;
-            }
-            // Optimize denormalized numbers with precomputed constant
-            float result = frac * 5.9604645E-8f; // Precomputed 2^-24
-            return sign == 0 ? result : -result;
-        }
-
-        // Normal case - optimize with fewer operations
-        float result = 1.0f + (frac / 1024.0f);
-
-        // Use bitshift instead of pow for integer powers of 2
-        if (exp < 15) {
-            int shift = 15 - exp;
-            result /= (1 << shift);
-        } else {
-            int shift = exp - 15;
-            result *= (1 << shift);
-        }
-
-        return sign == 0 ? result : -result;
-    }
-
-
-    /**
-     * Fused multiply-add operation: a * b + c
-     * This is typically implemented as a single instruction on modern hardware,
-     * providing better performance and numeric precision than separate operations.
-     *
-     * @param a First factor
-     * @param b Second factor
-     * @param c Value to add
-     * @return Result of a * b + c
-     */
-    private static float fma(float a, float b, float c) {
-        return a * b + c;
-    }
-
 }
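
The removed matmulTornadoQ8Optimized kernel documents the Q8_0 layout it consumes: each 32-element block stores a 2-byte FP16 scale followed by 32 signed 8-bit values (the kernel reads the scale low byte first, even though its comment says big-endian). For reference only, a minimal host-side sketch of dequantizing one such block; the helper name is hypothetical and Float.float16ToFloat requires JDK 20 or later:

    static float[] dequantizeQ8Block(byte[] block) {
        // Bytes 0-1: FP16 scale, low byte first, mirroring the removed kernel's read order
        short scaleBits = (short) (((block[1] & 0xFF) << 8) | (block[0] & 0xFF));
        float scale = Float.float16ToFloat(scaleBits); // JDK 20+
        float[] values = new float[32];
        for (int i = 0; i < 32; i++) {
            values[i] = block[2 + i] * scale; // bytes 2-33: signed int8 quantized values
        }
        return values;
    }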
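
The removed matmulTornadoQ4Optimized kernel packs two 4-bit values per byte: its code maps elements 0-15 to the lower nibbles and elements 16-31 to the upper nibbles (which differs from the interleaved even/odd layout its Javadoc describes) and subtracts 8 before scaling. A hypothetical sketch of the same unpacking for a single block, under the same JDK 20+ assumption:

    static float[] dequantizeQ4Block(byte[] block) {
        short scaleBits = (short) (((block[1] & 0xFF) << 8) | (block[0] & 0xFF));
        float scale = Float.float16ToFloat(scaleBits); // JDK 20+
        float[] values = new float[32];
        for (int i = 0; i < 32; i++) {
            int packed = block[2 + (i % 16)] & 0xFF;                  // bytes 2-17 hold the packed nibbles
            int nibble = (i < 16) ? (packed & 0x0F) : (packed >>> 4); // low nibbles first, then high
            values[i] = (nibble - 8) * scale;                         // Q4_0 applies a -8 offset
        }
        return values;
    }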
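
The removed decodeFloat16Fast expands the sign, exponent, and mantissa fields by hand. On JDK 20 and later the same conversion is available as Float.float16ToFloat, which additionally returns NaN for NaN inputs, whereas the removed helper folds them into infinity. A short usage sketch:

    short half = (short) 0x3C00;              // 1.0 encoded as IEEE 754 binary16
    float value = Float.float16ToFloat(half); // 1.0f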
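
Similarly, the removed fma helper computes a * b + c as two separately rounded operations; Math.fma (available since Java 9) performs the fused form with a single rounding, which is what hardware FMA instructions provide. Whether TornadoVM lowers either form to a device FMA is not shown in this diff.

    float fused = Math.fma(2.0f, 3.0f, 1.0f); // 7.0f, rounded once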

0 commit comments
