Skip to content

Commit 7bb5d8e

Browse files
committed
Increase work group size allocation for the final projection to increase work
Doubled the LOCAL_WORK_GROUP_SIZE_ALLOC value in several key areas to enhance computational parallelism and resource utilization. Removed an unused variable `gid` in `TransformerComputeKernelsLayered` for cleaner code. These changes aim to optimize kernel execution and ensure better scalability.
1 parent 0105645 commit 7bb5d8e

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,13 @@ private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
187187
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
188188
context,
189189
state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
190-
config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC); //
190+
config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC * 2); //
191191
break;
192192
case Q4_0:
193193
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
194194
context,
195195
state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
196-
config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC); //
196+
config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC * 2); //
197197
break;
198198
default:
199199
throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported.");
@@ -344,9 +344,9 @@ private GridScheduler setupGridSchedulersLayered() {
344344
// Vocabulary worker configuration
345345
// OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1])
346346
// CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>>
347-
int vocabSizeRowMajor = config.vocabularySize * LOCAL_WORK_GROUP_SIZE_ALLOC;
347+
int vocabSizeRowMajor = config.vocabularySize * LOCAL_WORK_GROUP_SIZE_ALLOC * 2 ;
348348
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
349-
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
349+
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * 2, 1, 1);
350350

351351
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
352352
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray
291291

292292
// Thread and workgroup information
293293
int tid = context.localIdx;
294-
int gid = context.globalIdx; // gid is not actively used in the core logic here
295294
int h = context.groupIdx; // Each workgroup processes one head
296295
int localSize = context.localGroupSizeX;
297296

0 commit comments

Comments
 (0)