@@ -144,7 +144,7 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
144144 .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
145145 context ,
146146 state .wrapLogits ,
147- weights .wclsByteArray ,
147+ weights .wclsHalfFloat ,
148148 weights .rms_final_weight_as_floatArray
149149 )
150150 .task ("reductionsOneBlockLogits" , TransformerComputeKernels ::reductionOneBlockWithLayer , context , state .tempLogits ,
@@ -184,14 +184,16 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
184184 private TaskGraph configureQuantizedMatrixVectorFinalWeight (TaskGraph logits ) {
185185 switch (weights .weightType ) {
186186 case Q8_0 :
187- logits .task ("projection" , TransformerComputeKernels ::matmulTornadoQ8Optimized , //
188- context , weights .wclsByteArray , state .wrapX , //
189- state .wrapLogits , config .dim ); //
187+ logits .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGeneric , //
188+ context ,
189+ state .wrapX , state .wrapLogits , weights .wclsHalfFloat , //
190+ config .dim , config .vocabularySize , LOCAL_WORK_GROUP_SIZE_ALLOC ); //
190191 break ;
191192 case Q4_0 :
192- logits .task ("projection" , TransformerComputeKernels ::matmulTornadoQ4Optimized , //
193- context , weights .wclsByteArray , state .wrapX , //
194- state .wrapLogits , config .dim ); //
193+ logits .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGeneric , //
194+ context ,
195+ state .wrapX , state .wrapLogits , weights .wclsHalfFloat , //
196+ config .dim , config .vocabularySize , LOCAL_WORK_GROUP_SIZE_ALLOC ); //
195197 break ;
196198 default :
197199 throw new UnsupportedOperationException ("Unsupported weight quantization type: " + weights .weightType + ". Only Q8_0 and Q4_0 are supported." );
@@ -342,9 +344,9 @@ private GridScheduler setupGridSchedulersLayered() {
342344 // Vocabulary worker configuration
343345 // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1])
344346 // CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>>
345- WorkerGrid vocabWorker = new WorkerGrid1D ( config .vocabularySize ) ;
346- vocabWorker . setGlobalWork ( config . vocabularySize , 1 , 1 );
347- vocabWorker .setLocalWork (16 , 1 , 1 );
347+ int vocabSizeRowMajor = config .vocabularySize * LOCAL_WORK_GROUP_SIZE_ALLOC ;
348+ WorkerGrid vocabWorker = new WorkerGrid1D ( vocabSizeRowMajor );
349+ vocabWorker .setLocalWork (LOCAL_WORK_GROUP_SIZE_ALLOC , 1 , 1 );
348350
349351 tornadoForwardScheduler .addWorkerGrid ("logits.projection" , vocabWorker );
350352 tornadoForwardScheduler .addWorkerGrid ("logits.reductionsOneBlockLogits" , rmsNormWorker );
0 commit comments