Skip to content

Commit 0105645

Browse files
committed
Refactor projection tasks and optimize vocabulary worker grid.
1 parent 6d18d95 commit 0105645

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
144144
.transferToDevice(DataTransferMode.FIRST_EXECUTION,
145145
context,
146146
state.wrapLogits,
147-
weights.wclsByteArray,
147+
weights.wclsHalfFloat,
148148
weights.rms_final_weight_as_floatArray
149149
)
150150
.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
@@ -184,14 +184,16 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
184184
private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
185185
switch (weights.weightType) {
186186
case Q8_0:
187-
logits.task("projection", TransformerComputeKernels::matmulTornadoQ8Optimized, //
188-
context, weights.wclsByteArray, state.wrapX, //
189-
state.wrapLogits, config.dim); //
187+
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
188+
context,
189+
state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
190+
config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC); //
190191
break;
191192
case Q4_0:
192-
logits.task("projection", TransformerComputeKernels::matmulTornadoQ4Optimized, //
193-
context, weights.wclsByteArray, state.wrapX, //
194-
state.wrapLogits, config.dim); //
193+
logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
194+
context,
195+
state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
196+
config.dim, config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC); //
195197
break;
196198
default:
197199
throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported.");
@@ -342,9 +344,9 @@ private GridScheduler setupGridSchedulersLayered() {
342344
// Vocabulary worker configuration
343345
// OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
344346
// CUDA equivalent: kernel<<<dim3(config.vocabularySize,1,1), dim3(LOCAL_WORK_GROUP_SIZE_ALLOC,1,1)>>>
345-
WorkerGrid vocabWorker = new WorkerGrid1D(config.vocabularySize);
346-
vocabWorker.setGlobalWork(config.vocabularySize, 1, 1);
347-
vocabWorker.setLocalWork(16, 1, 1);
347+
int vocabSizeRowMajor = config.vocabularySize * LOCAL_WORK_GROUP_SIZE_ALLOC;
348+
WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
349+
vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
348350

349351
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
350352
tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);

0 commit comments

Comments
 (0)