diff --git a/external/tornadovm b/external/tornadovm
index e1d2d12e..da800567 160000
--- a/external/tornadovm
+++ b/external/tornadovm
@@ -1 +1 @@
-Subproject commit e1d2d12e19f50a8e1d42f15aa0ab3c718bbed2c8
+Subproject commit da80056791474135a310a2faac718fa6361ab754
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
index a59ba97e..48852579 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
@@ -247,6 +247,52 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
         }
     }
 
+    /**
+     * Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel.
+     *
+     * Attention computation: 1. Compute attention scores (Q·K) 2. Apply softmax for attention weights 3. Compute weighted sum of values (attention·V)
+     *
+     * @param q
+     *            Query vectors for all heads
+     * @param key_cache
+     *            Cached key vectors
+     * @param value_cache
+     *            Cached value vectors
+     * @param xb
+     *            Output buffer for attention results
+     * @param nHeads
+     *            Number of attention heads
+     * @param headSize
+     *            Dimension of each head
+     * @param kvDim
+     *            Total key/value dimension
+     * @param kvMul
+     *            Key/value head multiplier for grouped-query attention
+     * @param seqLen
+     *            Current sequence length
+     * @param positionHolder
+     *            Array containing position and layer info
+     * @param wrapAtt
+     *            Buffer for attention weights
+     * @param layer
+     *            Current transformer layer
+     * @param contextLength
+     *            Maximum context length
+     */
+    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
+            IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
+
+        int pos = positionHolder.get(0);
+        int loff = layer * contextLength * kvDim;
+
+        // Parallelize computation across attention heads
+        for (@Parallel int h = 0; h < nHeads; h++) {
+            // Process each head in parallel
+            processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt);
+        }
+    }
+
     /**
      * Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
      *
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java
index 53428a40..dde074cd 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java
@@ -5,6 +5,8 @@ import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import uk.ac.manchester.tornado.api.KernelContext;
 
 /**
@@ -22,15 +24,19 @@ public abstract class QuantizedLayerPlanner
 ffnLayerTaskGraphs;
-    public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config) {
-        super(taskGraph, state, weights, config);
+    List ffnLayerTaskGraphs;
+
+    public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType schedulerType) {
+        super(taskGraph, state, weights, config, schedulerType);
         this.ffnLayerTaskGraphs = setupFFNLayered();
     }
 
@@ -111,9 +114,12 @@ TaskGraph setupSingleFFNLayer(FP16Weights weights, Configuration config, int lay
                 weights.w2Layered[layerIndex], weights.w3Layered[layerIndex]);
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
-        unifiedLayer
-                .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
-                .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
+        unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
+                    config.dim(), config.rmsNormEps());
+        }
+        unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
                 .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(),
@@ -122,14 +128,17 @@ TaskGraph setupSingleFFNLayer(FP16Weights weights, Configuration config, int lay
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
                 .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
-                        layerIndex, config.contextLength())
-                .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
-                        config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), state.positionHolder, layerIndex, config.contextLength())
-                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(),
+                        layerIndex, config.contextLength());
+        configureAttention(unifiedLayer, layerIndex);
+        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(),
-                        state.localSize)
-                .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
+                        state.localSize);
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
+                    config.dim(), config.rmsNormEps());
+        }
+        unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
                 .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex],
                         weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(),
@@ -159,4 +168,18 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
         return unifiedLayer;
     }
 
+    private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
+        if (schedulerType == SchedulerType.NVIDIA) {
+            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
+                    context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
+                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
+                    state.positionHolder, layerIndex, config.contextLength());
+        } else {
+            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel,
+                    state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
+                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(),
+                    state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
+        }
+    }
+
 }
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index d45d808a..cb543ab0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -8,6 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -22,10 +23,12 @@ public class LogitsFP16Layer extends AbstractLayer {
     private TaskGraph logitsTaskGraph;
     private ImmutableTaskGraph immutableLogitsGraph;
     private GridScheduler scheduler;
+    private final SchedulerType schedulerType;
 
-    public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID) {
+    public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
         super(name, state, weights, config);
         this.lastTaskGraphID = lastTaskGraphID;
+        this.schedulerType = schedulerType;
         state.tempLogits.init(0.0f);
         var fp16Weights = requireWeightsType(weights, FP16Weights.class, "LogitsFP16Layer", "FP16");
         this.logitsTaskGraph = setupLogitsTaskGraph(fp16Weights, config);
@@ -38,8 +41,12 @@ private TaskGraph setupLogitsTaskGraph(FP16Weights weights, Configuration config
         TaskGraph logits = new TaskGraph("logits");
         logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsHalfFloat, weights.rms_final_weight_as_floatArray)
-                .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
-                .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray, state.tempLogits)
+                .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (schedulerType == SchedulerType.NON_NVIDIA) {
+            logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits,
+                    config.dim(), config.rmsNormEps());
+        }
+        logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray, state.tempLogits)
                 .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, config.dim(), config.vocabularySize(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
index 789ebc63..88aa8c6e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
@@ -8,6 +8,7 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
 import uk.ac.manchester.tornado.api.GridScheduler;
@@ -46,8 +47,8 @@ public class Phi3FP16FFNLayers extends AbstractFFNLayers {
     // Phi3-specific dimension for combined QKV buffer
     private final int opSize;
 
-    public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         this.phi3State = state;
         this.phi3Config = config;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
index 86eefecb..032db1bf 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
@@ -7,6 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -37,8 +38,8 @@ public class Qwen2FP16FFNLayers extends AbstractFFNLayers {
     GridScheduler scheduler;
     List ffnLayerTaskGraphs;
 
-    public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         this.qwen2State = state;
         this.qwen2Config = config;
         ffnLayerTaskGraphs = setupFFNLayered();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
index b26b1c8c..d1777c55 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
@@ -6,6 +6,7 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -41,8 +42,8 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers {
     GridScheduler scheduler;
     List ffnLayerTaskGraphs;
 
-    public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         this.qwen3State = state;
         this.qwen3Config = config;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
index 5c649546..468fe717 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
@@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -20,8 +21,8 @@ public class LlamaQ8_0FFNLayers extends AbstractFFNLayers {
     GridScheduler scheduler;
     List ffnLayerTaskGraphs;
 
-    public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeightsQ8_0 weights, Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeightsQ8_0 weights, Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         ffnLayerTaskGraphs = setupFFNLayered();
     }
 
@@ -65,8 +66,13 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeightsQ8_0 weights, Configuration con
                 weights.woLayered[layerIndex].getScales(), weights.rms_ffn_weightLayered[layerIndex], weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(),
                 weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales());
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
-        unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
-                .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
+
+        unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
+                    config.dim(), config.rmsNormEps());
+        }
+        unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
                 .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(),
                         weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(),
@@ -75,13 +81,16 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeightsQ8_0 weights, Configuration con
                        weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
                 .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
-                        layerIndex, config.contextLength())
-                .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
-                        config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), state.positionHolder, layerIndex, config.contextLength())
-                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(),
+                        layerIndex, config.contextLength());
+        configureAttention(unifiedLayer, layerIndex);
+        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(),
                         weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
-                .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
+                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
+                    config.dim(), config.rmsNormEps());
+        }
+        unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
                 .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(),
                         weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
@@ -150,4 +159,18 @@ public List getFfnLayerTaskGraphs() {
         return ffnLayerTaskGraphs;
     }
 
+    private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
+        if (schedulerType == SchedulerType.NVIDIA) {
+            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
+                    context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
+                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
+                    state.positionHolder, layerIndex, config.contextLength());
+        } else {
+            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel,
+                    state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
+                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(),
+                    state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
+        }
+    }
+
 }
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
index e915aabf..b700bcf8 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
@@ -8,6 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -22,10 +23,12 @@ public class LogitsQ8_0Layer extends AbstractLayer {
     private TaskGraph logitsTaskGraph;
     private ImmutableTaskGraph immutableLogitsGraph;
     private GridScheduler scheduler;
+    private final SchedulerType schedulerType;
 
-    public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID) {
+    public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
         super(taskGraphName, state, weights, config);
         this.lastTaskGraphID = lastTaskGraphID;
+        this.schedulerType = schedulerType;
         state.tempLogits.init(0.0f);
         var q8_0Weights = requireWeightsType(weights, LlamaTornadoWeightsQ8_0.class, "LogitsQ8_0Layer", "Q8_0");
         this.logitsTaskGraph = setupLogitsTaskGraph(q8_0Weights, config);
@@ -54,12 +57,16 @@ private TaskGraph setupLogitsTaskGraph(LlamaTornadoWeightsQ8_0 weights, Configur
         TaskGraph logits = new TaskGraph("logits");
         logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsHalfFloat.getQuants(), weights.wclsHalfFloat.getScales(),
-                        weights.rms_final_weight_as_floatArray)
-                .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
-                .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray, state.tempLogits)
-                .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
-                        context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat.getQuants(), weights.wclsHalfFloat.getScales(), //
-                        config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) //
+                        weights.rms_final_weight_as_floatArray);
+        logits.task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (schedulerType == SchedulerType.NON_NVIDIA) {
+            logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits,
+                    config.dim(), config.rmsNormEps());
+        }
+        logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray, state.tempLogits)
+                .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat.getQuants(), weights.wclsHalfFloat.getScales(),
+                        config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS)
                 .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
         return logits;
     }
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
index e8d36851..9f4d6067 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
@@ -8,6 +8,7 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
 import uk.ac.manchester.tornado.api.GridScheduler;
@@ -46,8 +47,8 @@ public class Phi3Q8_0FFNLayers extends AbstractFFNLayers {
     // Phi3-specific dimension for combined QKV buffer
     private final int opSize;
 
-    public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeightsQ8_0 weights, Phi3Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeightsQ8_0 weights, Phi3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         this.phi3State = state;
         this.phi3Config = config;
         this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
index 9ba0f974..01540ef6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
@@ -7,6 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -41,8 +42,8 @@ public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers {
     private final Qwen2State qwen2State;
     private final Qwen2Configuration qwen2Config;
 
-    public Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeightsQ8_0 weights, Qwen2Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeightsQ8_0 weights, Qwen2Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         this.qwen2State = state;
         this.qwen2Config = config;
         ffnLayerTaskGraphs = setupFFNLayered();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
index fddabf69..89c34428 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
@@ -6,6 +6,7 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
 import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
 import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
 import uk.ac.manchester.tornado.api.GridScheduler;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -48,8 +49,8 @@ public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers {
     private final int nEmbdGqa;
     private final int gqa;
 
-    public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeightsQ8_0 weights, Qwen3Configuration config) {
-        super(taskGraphName, state, weights, config);
+    public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeightsQ8_0 weights, Qwen3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
         this.qwen3State = state;
         this.qwen3Config = config;
         this.nHeadKv = config.numberOfKeyValueHeads();
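Note: the Javadoc on the new processHeadsParallel kernel summarises the per-head computation as Q·K scoring, softmax, and a weighted sum over V, with processHeadTornado doing the work for each head. The plain-Java sketch below illustrates that math only; it is not the TornadoVM kernel, and the buffer layout (layer offset loff, kv head chosen via h / kvMul, scores staged in an att buffer like wrapAtt) is assumed from the parameter names in the diff.

```java
/** Simplified CPU-side sketch of one attention head; illustrative only, not the TornadoVM kernel. */
public final class AttentionHeadSketch {

    /** Scaled dot-product attention for head h over positions 0..pos (grouped-query layout assumed). */
    static void attendOneHead(float[] q, float[] keyCache, float[] valueCache, float[] xb,
                              int h, int headSize, int kvDim, int kvMul, int loff, int pos, float[] att) {
        int qOff = h * headSize;                // this head's slice of q and xb
        int kvOff = (h / kvMul) * headSize;     // GQA: several query heads share one kv head

        // 1. Attention scores: att[t] = (q_h . k_t) / sqrt(headSize)
        for (int t = 0; t <= pos; t++) {
            float score = 0.0f;
            for (int i = 0; i < headSize; i++) {
                score += q[qOff + i] * keyCache[loff + t * kvDim + kvOff + i];
            }
            att[t] = score / (float) Math.sqrt(headSize);
        }

        // 2. Softmax over positions 0..pos (max-subtracted for numerical stability)
        float max = Float.NEGATIVE_INFINITY;
        for (int t = 0; t <= pos; t++) {
            max = Math.max(max, att[t]);
        }
        float sum = 0.0f;
        for (int t = 0; t <= pos; t++) {
            att[t] = (float) Math.exp(att[t] - max);
            sum += att[t];
        }
        for (int t = 0; t <= pos; t++) {
            att[t] /= sum;
        }

        // 3. Weighted sum of values: xb_h = sum_t att[t] * v_t
        for (int i = 0; i < headSize; i++) {
            float acc = 0.0f;
            for (int t = 0; t <= pos; t++) {
                acc += att[t] * valueCache[loff + t * kvDim + kvOff + i];
            }
            xb[qOff + i] = acc;
        }
    }
}
```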
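Note: on non-NVIDIA schedulers the builders insert a reductionFinalNormalization task between the block-level RMSNorm reduction (reductionsOneBlock*) and the map step (mapContext* / mapContextLogits), presumably to combine per-work-group partial sums into the final scale where the single-pass kernel cannot be relied on to do so. That kernel's source is not part of this diff; the sketch below is only a plain-Java analogue of such a two-pass RMSNorm, with hypothetical helper names.

```java
/** Plain-Java analogue of a two-pass RMSNorm reduction; the real kernels operate on KernelContext work-groups. */
public final class RmsNormTwoPassSketch {

    /** Pass 1: each "work-group" writes its partial sum of squares into partials[group]. */
    static void partialSumsOfSquares(float[] x, int dim, int groups, float[] partials) {
        int chunk = (dim + groups - 1) / groups;
        for (int g = 0; g < groups; g++) {
            float acc = 0.0f;
            for (int i = g * chunk; i < Math.min(dim, (g + 1) * chunk); i++) {
                acc += x[i] * x[i];
            }
            partials[g] = acc;
        }
    }

    /** Pass 2 ("final normalization"): combine partials into the single scale 1 / sqrt(mean + eps). */
    static float finalNormalization(float[] partials, int groups, int dim, float eps) {
        float sum = 0.0f;
        for (int g = 0; g < groups; g++) {
            sum += partials[g];
        }
        return (float) (1.0 / Math.sqrt(sum / dim + eps));
    }

    /** Map step: scale each element by the computed factor and the learned RMSNorm weight. */
    static void apply(float[] out, float[] x, float[] weight, float scale, int dim) {
        for (int i = 0; i < dim; i++) {
            out[i] = weight[i] * (scale * x[i]);
        }
    }
}
```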
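Note: the same scheduler switch is duplicated in the FP16 and Q8_0 builders: NVIDIA keeps the KernelContext-based processHeadsFlashAttention task, other devices get processHeadsParallel plus the wrapAtt scores buffer, and shouldUseFinalNormalization() gates the extra reduction tasks. The self-contained sketch below only condenses that dispatch for reference; SchedulerKind mirrors the project's SchedulerType, and treating shouldUseFinalNormalization() as equivalent to a NON_NVIDIA check is an inference from the logits layers, not something stated in this diff.

```java
/** Condensed, illustrative view of the scheduler-based dispatch in this change set. */
public final class SchedulerDispatchSketch {

    /** Mirrors the project's SchedulerType; only the two cases used in the diff are modelled. */
    enum SchedulerKind { NVIDIA, NON_NVIDIA }

    /** Assumed behaviour of shouldUseFinalNormalization(): the extra reduction pass only applies off NVIDIA. */
    static boolean useFinalNormalization(SchedulerKind kind) {
        return kind == SchedulerKind.NON_NVIDIA;
    }

    /** Mirrors configureAttention(): flash attention on NVIDIA, portable per-head attention elsewhere. */
    static String attentionTask(SchedulerKind kind) {
        return kind == SchedulerKind.NVIDIA
                ? "processHeadsFlashAttention"  // KernelContext kernel, no separate attention buffer
                : "processHeadsParallel";       // @Parallel loop over heads, scores staged in wrapAtt
    }

    public static void main(String[] args) {
        for (SchedulerKind kind : SchedulerKind.values()) {
            System.out.println(kind + " -> " + attentionTask(kind)
                    + (useFinalNormalization(kind) ? " + reductionFinalNormalization" : ""));
        }
    }
}
```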