
Commit bb1676b

Merge pull request #7 from mikepapadim/optimization/flash_attention

Replace parallel attention with flash parallel attention

2 parents: 5d45815 + a6ac8e2

2 files changed: +141 -6 lines changed

src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java

Lines changed: 5 additions & 4 deletions
@@ -113,10 +113,10 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         config.headSize)
                 .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
                         state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim, layerIndex, config.contextLength)
-                .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel,
+                .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context,
                         state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
-                        config.numberOfHeads, config.headSize, config.kvDim, config.kvMul, config.vocabularySize,
-                        state.positionHolder, state.wrapAtt, layerIndex, config.contextLength)
+                        config.numberOfHeads, config.headSize, config.kvDim, config.kvMul,
+                        state.positionHolder, layerIndex, config.contextLength)
                 .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
                         state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim, config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
@@ -310,7 +310,8 @@ private GridScheduler setupGridSchedulersLayered() {
         // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1])
         // CUDA equivalent: kernel<<<dim3((config.numberOfHeads+3)/4,1,1), dim3(4,1,1)>>>
         WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads);
-        parallelAttentionWorker.setGlobalWork(config.numberOfHeads, 1, 1);
+        // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4
+        parallelAttentionWorker.setGlobalWork(config.numberOfHeads * 4, 1, 1);
         parallelAttentionWorker.setLocalWork(4, 1, 1); // Set local work size to 4 (for parallel attention)

         // Copy to caches worker configuration
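
Note on the scheduler change above: the new kernel assigns one head per workgroup (context.groupIdx) and lets the 4 threads of that workgroup cooperate on it, so the global work size must be numberOfHeads * 4 rather than numberOfHeads. A minimal plain-Java sketch of that index mapping (illustration only, not part of the commit; nHeads = 8 and localSize = 4 are assumed values):

// Illustration: how globalWork = nHeads * localSize and localWork = localSize
// map flat global thread ids to (head, lane) pairs, mirroring the kernel's use
// of context.groupIdx and context.localIdx.
public class LaunchMappingSketch {
    public static void main(String[] args) {
        int nHeads = 8;                       // assumed head count
        int localSize = 4;                    // matches setLocalWork(4, 1, 1)
        int globalSize = nHeads * localSize;  // matches setGlobalWork(config.numberOfHeads * 4, 1, 1)

        for (int gid = 0; gid < globalSize; gid++) {
            int head = gid / localSize;       // context.groupIdx in the kernel
            int lane = gid % localSize;       // context.localIdx in the kernel
            System.out.printf("globalIdx=%2d -> head=%d, lane=%d%n", gid, head, lane);
        }
    }
}

With the old setting (globalWork = numberOfHeads, localWork = 4) only numberOfHeads / 4 workgroups would be launched, leaving most heads unprocessed by a kernel that indexes heads by workgroup.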

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 136 additions & 2 deletions
@@ -194,7 +194,7 @@ public static void ropeRotation(KernelContext context, IntArray positionHolder,
      * @param contextLength Maximum context length
      */
     public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
-            IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
+                    IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {

         int pos = positionHolder.get(0);
         int loff = layer * contextLength * kvDim;
@@ -228,7 +228,7 @@ public static void processHeadsParallel(FloatArray q, FloatArray key_cache, Floa
      * @param wrapAtt Attention weights buffer
      */
     private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
-            FloatArray wrapAtt) {
+                    FloatArray wrapAtt) {

         // Base index for this head's attention weights
         int headOffset = h * (pos + 1);
@@ -285,6 +285,140 @@ private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, Fl
         }
     }

+    public static void processHeadsFlashAttention(
+            KernelContext context,
+            FloatArray q,
+            FloatArray key_cache,
+            FloatArray value_cache,
+            FloatArray xb,
+            int nHeads,
+            int headSize,
+            int kvDim,
+            int kvMul,
+            IntArray positionHolder,
+            int layer,
+            int contextLength) {
+
+        // Thread and workgroup information
+        int tid = context.localIdx;
+        int gid = context.globalIdx; // gid is not actively used in the core logic here
+        int h = context.groupIdx; // Each workgroup processes one head
+        int localSize = context.localGroupSizeX;
+
+        // Early exit if this workgroup is beyond our head count
+        // This relies on the kernel being launched with nHeads workgroups.
+        if (h >= nHeads) return;
+
+        int pos = positionHolder.get(0);
+        int loff = layer * contextLength * kvDim;
+        int kvHeadIdx = h / kvMul;
+        int BLOCK_SIZE_C = 4;
+
+        // Allocate shared memory for tiled computation
+        float[] q_shared = context.allocateFloatLocalArray(headSize);
+        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
+        float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max
+
+        // Thread-local accumulators for online softmax
+        float maxScore = Float.NEGATIVE_INFINITY;
+        float sumExp = 0.0f;
+
+        // Thread-local output accumulation
+        float[] output = new float[headSize];
+        for (int i = 0; i < headSize; i++) {
+            output[i] = 0.0f;
+        }
+
+        // Load query vector into shared memory
+        for (int i = tid; i < headSize; i += localSize) {
+            q_shared[i] = q.get(h * headSize + i);
+        }
+
+        context.localBarrier();
+
+        // Process sequence in tiles
+        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
+            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
+
+            // Load key and value vectors for this tile
+            // Each thread loads a portion of the K and V vectors for the tile
+            for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+                int k_v_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile
+                int tileMemOffset = k_v_idx_in_tile * headSize;
+                for (int d = 0; d < headSize; d++) {
+                    int kvCacheAbsolutePos = tIdxInSeq;
+                    int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d;
+                    k_tile[tileMemOffset + d] = key_cache.get(kvOffset);
+                    v_tile[tileMemOffset + d] = value_cache.get(kvOffset);
+                }
+            }
+
+            context.localBarrier();
+
+            // Compute attention scores for this tile
+            // Each thread computes one score for the tile
+            for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+                int score_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile
+
+                float score = 0.0f;
+                for (int d = 0; d < headSize; d++) {
+                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
+                }
+                score /= TornadoMath.sqrt(headSize);
+                s_tile[score_idx_in_tile] = score;
+            }
+
+            context.localBarrier();
+
+            // Find max score in this tile (all threads compute it redundantly over the small s_tile)
+            float tileLocalMax = Float.NEGATIVE_INFINITY;
+            for (int i = 0; i <= tileEnd - tileC; i++) { // Iterate over valid scores in s_tile
+                if (s_tile[i] > tileLocalMax) {
+                    tileLocalMax = s_tile[i];
+                }
+            }
+
+            // Broadcast max to all threads via shared memory
+            if (tid == 0) {
+                shared_tile_max_holder[0] = tileLocalMax; // FIX: Use dedicated holder
+            }
+            context.localBarrier();
+            float currentTileMax = shared_tile_max_holder[0]; // FIX: Read from dedicated holder
+
+            // Determine if we need to rescale previous results
+            float newMax = Math.max(maxScore, currentTileMax);
+            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
+                float scale = TornadoMath.exp(maxScore - newMax);
+                sumExp *= scale;
+                for (int d = 0; d < headSize; d++) {
+                    output[d] *= scale;
+                }
+            }
+            maxScore = newMax;
+
+            // Process each key-value pair using original scores from s_tile
+            // All threads iterate over all scores in the current tile
+            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
+                // s_tile[t_idx_in_s_tile] now correctly refers to the original score
+                float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
+                sumExp += expScore;
+
+                for (int d = 0; d < headSize; d++) {
+                    output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
+                }
+            }
+            context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
+        }
+
+        // Normalize and write final results
+        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; // Avoid division by zero, return 0 if sumExp is 0
+        for (int d = tid; d < headSize; d += localSize) {
+            xb.set(h * headSize + d, output[d] * normFactor);
+        }
+    }
+
     /**
      * Performs optimized matrix-vector multiplication where each work group
      * processes one row of the matrix.
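
The new kernel never materializes a full row of attention weights, which is why state.wrapAtt and config.vocabularySize drop out of the planner call in TornadoVMLayerPlanner.java above. Instead it keeps a running maximum and a running sum of exponentials and rescales the partial output whenever the maximum grows (online softmax). A single-threaded sketch of that recurrence, with hypothetical inputs scores and values (illustration only, not part of the commit):

import java.util.Arrays;

// Illustration: the online-softmax accumulation the kernel performs per tile,
// written as a plain loop over one head's scores. scores[t] stands for
// q . k_t / sqrt(headSize) and values[t] for the value vector at position t.
public final class OnlineSoftmaxSketch {
    public static float[] attend(float[] scores, float[][] values, int headSize) {
        float maxScore = Float.NEGATIVE_INFINITY;
        float sumExp = 0.0f;
        float[] output = new float[headSize];

        for (int t = 0; t < scores.length; t++) {
            float newMax = Math.max(maxScore, scores[t]);
            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
                // A larger max appeared: rescale the previous partial sums,
                // mirroring scale = exp(maxScore - newMax) in the kernel.
                float scale = (float) Math.exp(maxScore - newMax);
                sumExp *= scale;
                for (int d = 0; d < headSize; d++) {
                    output[d] *= scale;
                }
            }
            maxScore = newMax;

            float expScore = (float) Math.exp(scores[t] - maxScore);
            sumExp += expScore;
            for (int d = 0; d < headSize; d++) {
                output[d] += expScore * values[t][d];
            }
        }

        // Normalize at the end, as the kernel does before writing xb
        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
        for (int d = 0; d < headSize; d++) {
            output[d] *= normFactor;
        }
        return output;
    }

    public static void main(String[] args) {
        float[] scores = {0.5f, 2.0f, -1.0f};
        float[][] values = {{1f, 0f}, {0f, 1f}, {1f, 1f}};
        System.out.println(Arrays.toString(attend(scores, values, 2)));
    }
}

Processing one position at a time gives the same result as the kernel's 4-wide tiles; tiling only changes how much work falls between barriers and how much K/V data sits in local memory at once.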
