@@ -194,7 +194,7 @@ public static void ropeRotation(KernelContext context, IntArray positionHolder,
194194 * @param contextLength Maximum context length
195195 */
196196 public static void processHeadsParallel (FloatArray q , FloatArray key_cache , FloatArray value_cache , FloatArray xb , int nHeads , int headSize , int kvDim , int kvMul , int seqLen ,
197- IntArray positionHolder , FloatArray wrapAtt , int layer , int contextLength ) {
197+ IntArray positionHolder , FloatArray wrapAtt , int layer , int contextLength ) {
198198
199199 int pos = positionHolder .get (0 );
200200 int loff = layer * contextLength * kvDim ;
@@ -228,7 +228,7 @@ public static void processHeadsParallel(FloatArray q, FloatArray key_cache, Floa
228228 * @param wrapAtt Attention weights buffer
229229 */
230230 private static void processHeadTornado (FloatArray allQ , FloatArray key_cache , FloatArray value_cache , FloatArray allXb , int h , int headSize , int kvDim , int kvMul , long loff , int pos ,
231- FloatArray wrapAtt ) {
231+ FloatArray wrapAtt ) {
232232
233233 // Base index for this head's attention weights
234234 int headOffset = h * (pos + 1 );
@@ -285,6 +285,140 @@ private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, Fl
285285 }
286286 }
287287
288+ public static void processHeadsFlashAttention (
289+ KernelContext context ,
290+ FloatArray q ,
291+ FloatArray key_cache ,
292+ FloatArray value_cache ,
293+ FloatArray xb ,
294+ int nHeads ,
295+ int headSize ,
296+ int kvDim ,
297+ int kvMul ,
298+ IntArray positionHolder ,
299+ int layer ,
300+ int contextLength ) {
301+
302+ // Thread and workgroup information
303+ int tid = context .localIdx ;
304+ int gid = context .globalIdx ; // gid is not actively used in the core logic here
305+ int h = context .groupIdx ; // Each workgroup processes one head
306+ int localSize = context .localGroupSizeX ;
307+
308+ // Early exit if this workgroup is beyond our head count
309+ // This relies on the kernel being launched with nHeads workgroups.
310+ if (h >= nHeads ) return ;
311+
312+ int pos = positionHolder .get (0 );
313+ int loff = layer * contextLength * kvDim ;
314+ int kvHeadIdx = h / kvMul ;
315+ int BLOCK_SIZE_C = 4 ;
316+
317+ // Allocate shared memory for tiled computation
318+ float [] q_shared = context .allocateFloatLocalArray (headSize );
319+ float [] k_tile = context .allocateFloatLocalArray (BLOCK_SIZE_C * headSize );
320+ float [] v_tile = context .allocateFloatLocalArray (BLOCK_SIZE_C * headSize );
321+ float [] s_tile = context .allocateFloatLocalArray (BLOCK_SIZE_C );
322+ float [] shared_tile_max_holder = context .allocateFloatLocalArray (1 ); // FIX: For broadcasting tile max
323+
324+ // Thread-local accumulators for online softmax
325+ float maxScore = Float .NEGATIVE_INFINITY ;
326+ float sumExp = 0.0f ;
327+
328+ // Thread-local output accumulation
329+ float [] output = new float [headSize ];
330+ for (int i = 0 ; i < headSize ; i ++) {
331+ output [i ] = 0.0f ;
332+ }
333+
334+ // Load query vector into shared memory
335+ for (int i = tid ; i < headSize ; i += localSize ) {
336+ q_shared [i ] = q .get (h * headSize + i );
337+ }
338+
339+ context .localBarrier ();
340+
341+ // Process sequence in tiles
342+ for (int tileC = 0 ; tileC <= pos ; tileC += BLOCK_SIZE_C ) {
343+ int tileEnd = Math .min (tileC + BLOCK_SIZE_C - 1 , pos );
344+
345+ // Load key and value vectors for this tile
346+ // Each thread loads a portion of the K and V vectors for the tile
347+ for (int tIdxInSeq = tileC + tid ; tIdxInSeq <= tileEnd ; tIdxInSeq += localSize ) {
348+ int k_v_idx_in_tile = tIdxInSeq - tileC ; // 0, 1, 2, or 3 for this tile
349+ int tileMemOffset = k_v_idx_in_tile * headSize ;
350+ for (int d = 0 ; d < headSize ; d ++) {
351+ int kvCacheAbsolutePos = tIdxInSeq ;
352+ int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d ;
353+ k_tile [tileMemOffset + d ] = key_cache .get (kvOffset );
354+ v_tile [tileMemOffset + d ] = value_cache .get (kvOffset );
355+ }
356+ }
357+
358+ context .localBarrier ();
359+
360+ // Compute attention scores for this tile
361+ // Each thread computes one score for the tile
362+ for (int tIdxInSeq = tileC + tid ; tIdxInSeq <= tileEnd ; tIdxInSeq += localSize ) {
363+ int score_idx_in_tile = tIdxInSeq - tileC ; // 0, 1, 2, or 3 for this tile
364+
365+ float score = 0.0f ;
366+ for (int d = 0 ; d < headSize ; d ++) {
367+ score += q_shared [d ] * k_tile [score_idx_in_tile * headSize + d ];
368+ }
369+ score /= TornadoMath .sqrt (headSize );
370+ s_tile [score_idx_in_tile ] = score ;
371+ }
372+
373+ context .localBarrier ();
374+
375+ // Find max score in this tile (all threads compute it redundantly over the small s_tile)
376+ float tileLocalMax = Float .NEGATIVE_INFINITY ;
377+ for (int i = 0 ; i <= tileEnd - tileC ; i ++) { // Iterate over valid scores in s_tile
378+ if (s_tile [i ] > tileLocalMax ) {
379+ tileLocalMax = s_tile [i ];
380+ }
381+ }
382+
383+ // Broadcast max to all threads via shared memory
384+ if (tid == 0 ) {
385+ shared_tile_max_holder [0 ] = tileLocalMax ; // FIX: Use dedicated holder
386+ }
387+ context .localBarrier ();
388+ float currentTileMax = shared_tile_max_holder [0 ]; // FIX: Read from dedicated holder
389+
390+ // Determine if we need to rescale previous results
391+ float newMax = Math .max (maxScore , currentTileMax );
392+ if (newMax != maxScore && maxScore != Float .NEGATIVE_INFINITY ) {
393+ float scale = TornadoMath .exp (maxScore - newMax );
394+ sumExp *= scale ;
395+ for (int d = 0 ; d < headSize ; d ++) {
396+ output [d ] *= scale ;
397+ }
398+ }
399+ maxScore = newMax ;
400+
401+ // Process each key-value pair using original scores from s_tile
402+ // All threads iterate over all scores in the current tile
403+ for (int t_idx_in_s_tile = 0 ; t_idx_in_s_tile <= tileEnd - tileC ; t_idx_in_s_tile ++) {
404+ // s_tile[t_idx_in_s_tile] now correctly refers to the original score
405+ float expScore = TornadoMath .exp (s_tile [t_idx_in_s_tile ] - maxScore );
406+ sumExp += expScore ;
407+
408+ for (int d = 0 ; d < headSize ; d ++) {
409+ output [d ] += expScore * v_tile [t_idx_in_s_tile * headSize + d ];
410+ }
411+ }
412+ context .localBarrier (); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
413+ }
414+
415+ // Normalize and write final results
416+ float normFactor = (sumExp > 0.0f ) ? (1.0f / sumExp ) : 0.0f ; // Avoid division by zero, return 0 if sumExp is 0
417+ for (int d = tid ; d < headSize ; d += localSize ) {
418+ xb .set (h * headSize + d , output [d ] * normFactor );
419+ }
420+ }
421+
288422 /**
289423 * Performs optimized matrix-vector multiplication where each work group
290424 * processes one row of the matrix.
0 commit comments