external/tornadovm — 2 changes: 1 addition & 1 deletion
Submodule tornadovm updated 110 files
@@ -247,6 +247,52 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
}
}

/**
* Orchestrates multi-head attention computation, processing each head independently and in parallel.
*
* Attention computation:
* 1. Compute attention scores (Q·K)
* 2. Apply softmax for attention weights
* 3. Compute weighted sum of values (attention·V)
*
* @param q
* Query vectors for all heads
* @param key_cache
* Cached key vectors
* @param value_cache
* Cached value vectors
* @param xb
* Output buffer for attention results
* @param nHeads
* Number of attention heads
* @param headSize
* Dimension of each head
* @param kvDim
* Total key/value dimension
* @param kvMul
* Key/value head multiplier for grouped-query attention
* @param seqLen
* Current sequence length
* @param positionHolder
* Array containing position and layer info
* @param wrapAtt
* Buffer for attention weights
* @param layer
* Current transformer layer
* @param contextLength
* Maximum context length
*/
public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {

int pos = positionHolder.get(0);
int loff = layer * contextLength * kvDim;

// Parallelize computation across attention heads
for (@Parallel int h = 0; h < nHeads; h++) {
// Process each head in parallel
processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt);
}
}
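
For reference, the per-head computation described above is standard scaled dot-product attention. The sketch below is a plain-Java illustration, not part of this diff: the method name, the flat cache layout (loff + t * kvDim + kvHead * headSize), and the use of ordinary float[] buffers are assumptions made for clarity.

    // Illustrative sketch only (not the actual TornadoVM kernel): single-head scaled
    // dot-product attention with grouped-query K/V sharing, following the parameter
    // descriptions in the javadoc above.
    static void attendSingleHead(float[] q, float[] keyCache, float[] valueCache, float[] xb,
            float[] att, int h, int headSize, int kvDim, int kvMul, int loff, int pos) {
        int qOffset = h * headSize;   // this head's slice of the query vector
        int kvHead = h / kvMul;       // grouped-query attention: several Q heads share one K/V head
        // 1. Attention scores: dot(Q_h, K_t) / sqrt(headSize) for t = 0..pos
        for (int t = 0; t <= pos; t++) {
            int kOffset = loff + t * kvDim + kvHead * headSize;
            float score = 0f;
            for (int i = 0; i < headSize; i++) {
                score += q[qOffset + i] * keyCache[kOffset + i];
            }
            att[t] = score / (float) Math.sqrt(headSize);
        }
        // 2. Softmax over scores 0..pos (subtract the max for numerical stability)
        float max = Float.NEGATIVE_INFINITY;
        for (int t = 0; t <= pos; t++) {
            max = Math.max(max, att[t]);
        }
        float sum = 0f;
        for (int t = 0; t <= pos; t++) {
            att[t] = (float) Math.exp(att[t] - max);
            sum += att[t];
        }
        for (int t = 0; t <= pos; t++) {
            att[t] /= sum;
        }
        // 3. Weighted sum of values, written to this head's slice of the output buffer
        for (int i = 0; i < headSize; i++) {
            float weighted = 0f;
            for (int t = 0; t <= pos; t++) {
                int vOffset = loff + t * kvDim + kvHead * headSize;
                weighted += att[t] * valueCache[vOffset + i];
            }
            xb[qOffset + i] = weighted;
        }
    }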


/**
* Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
*
@@ -5,6 +5,8 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import uk.ac.manchester.tornado.api.KernelContext;

/**
@@ -22,15 +24,19 @@ public abstract class QuantizedLayerPlanner<S extends State, C extends Configura
protected final C config;
protected final W weights;
protected final KernelContext context;
protected final Model model;
protected final SchedulerType schedulerType;

/**
* Constructor: validate quantization type, extract model components
*/
protected QuantizedLayerPlanner(S state, Model model) {
this.state = state;
this.model = model;
this.config = (C) model.configuration();
this.weights = (W) model.weights();
this.context = new KernelContext();
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);

validateQuantizationType();
}
@@ -20,8 +20,8 @@ public LlamaFP16LayerPlanner(LlamaState state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
@@ -27,8 +27,8 @@ public Phi3FP16LayerPlanner(Phi3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
@@ -27,7 +27,7 @@ public Qwen2FP16LayerPlanner(Qwen2State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}
}
@@ -27,8 +27,8 @@ public Qwen3FP16LayerPlanner(Qwen3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
@@ -20,8 +20,8 @@ public LlamaQ8_0LayerPlanner(LlamaState state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
@@ -28,8 +28,8 @@ public Phi3Q8_0LayerPlanner(Phi3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
@@ -28,8 +28,8 @@ public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
@@ -28,7 +28,7 @@ public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config);
this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}
}
@@ -9,17 +9,11 @@

public class SchedulerDetectionService {


public static SchedulerType determineSchedulerType(Model model) {
TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime();
String platformName = tornadoRuntime.getBackend(0)
.getDefaultDevice()
.getPlatformName()
.toLowerCase(Locale.ROOT);
String platformName = tornadoRuntime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase(Locale.ROOT);

boolean isNvidia = platformName.contains("nvidia") ||
platformName.contains("cuda") ||
platformName.contains("ptx");
boolean isNvidia = platformName.contains("nvidia") || platformName.contains("cuda") || platformName.contains("ptx");
boolean isNotMistral = model.getModelType() != ModelType.MISTRAL;

return (isNvidia && isNotMistral) ? SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA;
@@ -3,6 +3,7 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;

import java.util.List;
@@ -20,6 +21,7 @@
public abstract class AbstractFFNLayers extends AbstractLayer {

protected String lastTaskGraphID;
protected final SchedulerType schedulerType;

/**
* Constructor for FFN layers.
@@ -32,9 +34,12 @@ public abstract class AbstractFFNLayers {
* Model weights (FP16Weights, Q8_0Weights, etc.)
* @param config
* Model configuration
* @param schedulerType
* Scheduler type (NVIDIA or NON_NVIDIA) for hardware-specific optimizations
*/
protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config) {
protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config);
this.schedulerType = schedulerType;
}

/**
@@ -66,4 +71,18 @@ public String getLastTaskGraphID() {
public void clearLastTaskGraphID() {
lastTaskGraphID = null;
}

/**
* Indicates whether an explicit final normalization step should be added to the task graph.
*
* - NVIDIA hardware: Flash Attention path, no extra normalization task
* - NON_NVIDIA hardware: parallel head processing path, requires the final normalization task
*
* Subclasses should consult this flag during task graph setup.
*
* @return true if the final normalization step should be used (NON_NVIDIA), false otherwise
*/
protected boolean shouldUseFinalNormalization() {
return schedulerType == SchedulerType.NON_NVIDIA;
}
}
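
A concrete FFN layers class is expected to consult this flag while assembling its task graph. The condensed sketch below mirrors the LlamaFP16FFNLayers changes later in this diff; the helper method name addRmsNormTasks and the assumption that context, state, and config are accessible at this level are illustrative, not part of the PR.

    // Sketch (based on LlamaFP16FFNLayers below): schedule the RMS-norm reduction, then
    // append the extra normalization task only on the NON_NVIDIA path.
    protected TaskGraph addRmsNormTasks(TaskGraph unifiedLayer) {
        unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
        if (shouldUseFinalNormalization()) {
            unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization,
                    context, state.temp, config.dim(), config.rmsNormEps());
        }
        return unifiedLayer;
    }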
@@ -18,7 +18,7 @@ public abstract class AbstractLayer {

/** Common constants used in tasks & worker-grid sizing. */
protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
protected static final int THREAD_SCALE_FOR_LOGITS = 1;
protected static final int THREAD_SCALE_FOR_LOGITS = 8;
protected static String lastTaskGraphID;
protected final Weights weights;
protected final Configuration config;
@@ -7,11 +7,13 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.common.TornadoFunctions;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;

import java.util.List;
@@ -21,9 +23,10 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers {

TaskGraph ffnTaskGraphs;
GridScheduler scheduler;
List<ImmutableTaskGraph> ffnLayerTaskGraphs;
public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config) {
super(taskGraph, state, weights, config);
List<ImmutableTaskGraph> ffnLayerTaskGraphs;

public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
this.ffnLayerTaskGraphs = setupFFNLayered();
}

@@ -111,9 +114,12 @@ TaskGraph setupSingleFFNLayer(FP16Weights weights, Configuration config, int lay
weights.w2Layered[layerIndex],
weights.w3Layered[layerIndex]);
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
unifiedLayer
.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
if (shouldUseFinalNormalization()) {
unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
config.dim(), config.rmsNormEps());
}
unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(),
LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(),
@@ -122,14 +128,17 @@
LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
layerIndex, config.contextLength())
.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), state.positionHolder, layerIndex, config.contextLength())
.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(),
layerIndex, config.contextLength());
configureAttention(unifiedLayer, layerIndex);
unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(),
LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(),
state.localSize)
.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
state.localSize);
if (shouldUseFinalNormalization()) {
unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
config.dim(), config.rmsNormEps());
}
unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex],
weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(),
@@ -159,4 +168,18 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
return unifiedLayer;
}

private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
if (schedulerType == SchedulerType.NVIDIA) {
return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
state.positionHolder, layerIndex, config.contextLength());
} else {
return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel,
state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(),
state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
Review comment (Copilot AI, Nov 11, 2025): The parameter layerIndex is passed twice in the processHeadsParallel call - once as the second-to-last parameter and once as the last parameter. According to the method signature in TransformerComputeKernelsLayered.java (line 282-283), the parameters should be: q, key_cache, value_cache, xb, nHeads, headSize, kvDim, kvMul, seqLen, positionHolder, wrapAtt, layer, contextLength. Here layerIndex appears to be passed as both layer and contextLength, which is incorrect.
}
}

}