external/tornadovm: 2 changes (1 addition, 1 deletion)
Submodule tornadovm updated 110 files
@@ -59,7 +59,7 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
return ModelType.QWEN_2;
} else if (lowerName.contains("qwen3")) {
return ModelType.QWEN_3;
-        } else if (lowerName.contains("deepseek r1 distill")) {
+        } else if (lowerName.contains("deepseek")) {
Collaborator: i think we still need this

return ModelType.DEEPSEEK_R1_DISTILL_QWEN;
} else if (lowerName.contains("phi3") || lowerName.contains("phi-3")) {
return ModelType.PHI_3;
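
To make the review comment above concrete: the broadened check routes any model name containing "deepseek" to the R1-distill branch, not only the distilled models the old check matched. A quick illustration (the model name is hypothetical):

String lowerName = "deepseek-v3-chat";                         // hypothetical model name
boolean oldMatch = lowerName.contains("deepseek r1 distill");  // false: previously fell through
boolean newMatch = lowerName.contains("deepseek");             // true: now mapped to DEEPSEEK_R1_DISTILL_QWEN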
@@ -247,6 +247,52 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHolder
}
}

/**
* Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel.
*
* Attention computation:
* 1. Compute attention scores (Q·K)
* 2. Apply softmax for attention weights
* 3. Compute weighted sum of values (attention·V)
*
* @param q
* Query vectors for all heads
* @param key_cache
* Cached key vectors
* @param value_cache
* Cached value vectors
* @param xb
* Output buffer for attention results
* @param nHeads
* Number of attention heads
* @param headSize
* Dimension of each head
* @param kvDim
* Total key/value dimension
* @param kvMul
* Key/value head multiplier for grouped-query attention
* @param seqLen
* Current sequence length
* @param positionHolder
* Array containing position and layer info
* @param wrapAtt
* Buffer for attention weights
* @param layer
* Current transformer layer
* @param contextLength
* Maximum context length
*/
public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {

int pos = positionHolder.get(0);
int loff = layer * contextLength * kvDim;

// Parallelize computation across attention heads
for (@Parallel int h = 0; h < nHeads; h++) {
// Process each head in parallel
processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt);
}
}
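
The per-head kernel is documented just below, but its body is collapsed in this diff. For reference, here is a minimal plain-Java sketch of the three steps listed in the javadoc above. The method name, the grouped-query indexing via kvMul and loff, and the local score buffer (standing in for the shared wrapAtt array) are assumptions inferred from the signature, not the PR's actual code:

// Sketch only: plain-Java single-head scaled dot-product attention.
static void attendOneHeadSketch(float[] q, float[] keyCache, float[] valueCache, float[] xb,
        int h, int headSize, int kvDim, int kvMul, int loff, int pos) {
    int kvHead = h / kvMul;      // grouped-query attention: query heads share KV heads
    int qOff = h * headSize;
    float[] scores = new float[pos + 1];

    // 1. Attention scores: scaled dot product Q·K over all cached positions.
    float scale = (float) (1.0 / Math.sqrt(headSize));
    for (int t = 0; t <= pos; t++) {
        int kOff = loff + t * kvDim + kvHead * headSize;
        float dot = 0f;
        for (int i = 0; i < headSize; i++) {
            dot += q[qOff + i] * keyCache[kOff + i];
        }
        scores[t] = dot * scale;
    }

    // 2. Softmax over the scores (max-subtracted for numerical stability).
    float max = Float.NEGATIVE_INFINITY;
    for (int t = 0; t <= pos; t++) {
        max = Math.max(max, scores[t]);
    }
    float sum = 0f;
    for (int t = 0; t <= pos; t++) {
        scores[t] = (float) Math.exp(scores[t] - max);
        sum += scores[t];
    }

    // 3. Weighted sum of values (attention·V) into the output buffer.
    for (int i = 0; i < headSize; i++) {
        float acc = 0f;
        for (int t = 0; t <= pos; t++) {
            acc += (scores[t] / sum) * valueCache[loff + t * kvDim + kvHead * headSize + i];
        }
        xb[qOff + i] = acc;
    }
}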


/**
* Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
*
QuantizedLayerPlanner.java
@@ -5,6 +5,8 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import uk.ac.manchester.tornado.api.KernelContext;

/**
@@ -22,15 +24,19 @@ public abstract class QuantizedLayerPlanner<S extends State, C extends Configura
protected final C config;
protected final W weights;
protected final KernelContext context;
protected final Model model;
protected final SchedulerType schedulerType;

/**
* Constructor: validate quantization type, extract model components
*/
protected QuantizedLayerPlanner(S state, Model model) {
this.state = state;
this.model = model;
this.config = (C) model.configuration();
this.weights = (W) model.weights();
this.context = new KernelContext();
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);

validateQuantizationType();
}
LlamaFP16LayerPlanner.java
@@ -20,8 +20,8 @@ public LlamaFP16LayerPlanner(LlamaState state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
Phi3FP16LayerPlanner.java
@@ -27,8 +27,8 @@ public Phi3FP16LayerPlanner(Phi3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
Qwen2FP16LayerPlanner.java
@@ -27,7 +27,7 @@ public Qwen2FP16LayerPlanner(Qwen2State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}
}
Qwen3FP16LayerPlanner.java
@@ -27,8 +27,8 @@ public Qwen3FP16LayerPlanner(Qwen3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
LlamaQ8_0LayerPlanner.java
@@ -20,8 +20,8 @@ public LlamaQ8_0LayerPlanner(LlamaState state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
Phi3Q8_0LayerPlanner.java
@@ -28,8 +28,8 @@ public Phi3Q8_0LayerPlanner(Phi3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
Qwen2Q8_0LayerPlanner.java
@@ -28,8 +28,8 @@ public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}

}
Qwen3Q8_0LayerPlanner.java
@@ -28,7 +28,7 @@ public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) {
@Override
protected void initializeLayerComponents() {
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
-        this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config);
-        this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
+        this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
}
}
SchedulerDetectionService.java
@@ -9,17 +9,11 @@

public class SchedulerDetectionService {


public static SchedulerType determineSchedulerType(Model model) {
TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime();
-        String platformName = tornadoRuntime.getBackend(0)
-                .getDefaultDevice()
-                .getPlatformName()
-                .toLowerCase(Locale.ROOT);
+        String platformName = tornadoRuntime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase(Locale.ROOT);

-        boolean isNvidia = platformName.contains("nvidia") ||
-                platformName.contains("cuda") ||
-                platformName.contains("ptx");
+        boolean isNvidia = platformName.contains("nvidia") || platformName.contains("cuda") || platformName.contains("ptx");
boolean isNotMistral = model.getModelType() != ModelType.MISTRAL;

return (isNvidia && isNotMistral) ? SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA;
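
Concretely, the decision table implied by this logic (platform strings are illustrative):

// "NVIDIA CUDA" platform, non-Mistral model -> SchedulerType.NVIDIA
// "NVIDIA CUDA" platform, Mistral model     -> SchedulerType.NON_NVIDIA (Mistral is excluded)
// "Intel(R) OpenCL" platform, any model     -> SchedulerType.NON_NVIDIA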
AbstractFFNLayers.java
@@ -3,6 +3,7 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;

import java.util.List;
@@ -20,6 +21,7 @@
public abstract class AbstractFFNLayers extends AbstractLayer {

protected String lastTaskGraphID;
protected final SchedulerType schedulerType;

/**
* Constructor for FFN layers.
@@ -32,9 +34,12 @@ public abstract class AbstractFFNLayers extends AbstractLayer {
* Model weights (FP16Weights, Q8_0Weights, etc.)
* @param config
* Model configuration
* @param schedulerType
* Scheduler type (NVIDIA or NON_NVIDIA) for hardware-specific optimizations
*/
-    protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config) {
+    protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config);
this.schedulerType = schedulerType;
}

/**
@@ -66,4 +71,18 @@ public String getLastTaskGraphID() {
public void clearLastTaskGraphID() {
lastTaskGraphID = null;
}

/**
* Indicates which attention path applies for the detected hardware scheduler type:
*
* - NVIDIA hardware: Uses Flash Attention for optimized performance
* - NON_NVIDIA hardware: Uses parallel head processing
*
* This method should be called during task graph setup in subclasses.
*
* @return true if final normalization step should be used (NON_NVIDIA), false otherwise
*/
protected boolean shouldUseFinalNormalization() {
return schedulerType == SchedulerType.NON_NVIDIA;
}
}
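
As an illustration of the intended call pattern (not code from this PR), a subclass might branch on this flag while wiring its task graph. The helper methods below are invented placeholders for the real task wiring; only shouldUseFinalNormalization and TornadoVM's TaskGraph type are real:

// Hypothetical sketch, not part of this PR; helper methods are invented placeholders.
protected void setupAttention(TaskGraph graph) {
    if (shouldUseFinalNormalization()) {
        // NON_NVIDIA path: per-head parallel attention plus a separate final normalization task.
        addParallelHeadsTask(graph);
        addFinalNormalizationTask(graph);
    } else {
        // NVIDIA path: fused flash-attention kernel; no separate normalization step.
        addFlashAttentionTask(graph);
    }
}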
AbstractLayer.java
@@ -18,7 +18,8 @@ public abstract class AbstractLayer {

/** Common constants used in tasks & worker-grid sizing. */
protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
-    protected static final int THREAD_SCALE_FOR_LOGITS = 1;
+    // TODO: 1 OR 8?
+    protected static final int THREAD_SCALE_FOR_LOGITS = 8;
protected static String lastTaskGraphID;
protected final Weights weights;
protected final Configuration config;