44import com .example .loader .weights .State ;
55import com .example .model .Configuration ;
66import com .example .model .Model ;
7+ import com .example .model .ModelType ;
78import uk .ac .manchester .tornado .api .GridScheduler ;
89import uk .ac .manchester .tornado .api .ImmutableTaskGraph ;
910import uk .ac .manchester .tornado .api .TornadoExecutionPlan ;
1213import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
1314
1415import java .util .List ;
16+ import java .util .Locale ;
1517
1618public class TornadoVMMasterPlan {
1719 private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean .parseBoolean (System .getProperty ("llama.EnableTimingForTornadoVMInit" , "False" ));
@@ -22,9 +24,9 @@ public class TornadoVMMasterPlan {
2224 public TornadoExecutionPlan executionPlan ;
2325 List <ImmutableTaskGraph > taskGraphs ;
2426
25- public TornadoVMMasterPlan (State state , Model model , boolean isNvidia ) {
27+ public TornadoVMMasterPlan (State state , Model model ) {
2628 TornadoVMLayerPlanner tornadoVMLayerPlanner = new TornadoVMLayerPlanner (state , model );
27- Tuple2 <List <ImmutableTaskGraph >, GridScheduler > tornadoVMPlan = isNvidia
29+ Tuple2 <List <ImmutableTaskGraph >, GridScheduler > tornadoVMPlan = shouldUseNvidiaScheduler ( model )
2830 ? tornadoVMLayerPlanner .setupTornadoForwardPlanLayered ()
2931 : tornadoVMLayerPlanner .setupTornadoForwardPlanLayeredNonNvidia ();
3032 this .taskGraphs = tornadoVMPlan .getFirst ();
@@ -57,9 +59,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
5759 }
5860
5961 // 1. Pre-allocate the TornadoVM plan
60- TornadoRuntime coreRuntime = TornadoRuntimeProvider .getTornadoRuntime ();
61- boolean isNvidia = coreRuntime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase ().contains ("nvidia" );
62- TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan (state , model , isNvidia );
62+ TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan (state , model );
6363
6464 // Record time after plan creation
6565 if (ENABLE_TORNADOVM_INIT_TIME ) {
@@ -89,6 +89,29 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
8989 return tornadoVMPlan ;
9090 }
9191
92+ /**
93+ * Determines whether the NVIDIA-specific scheduler should be used based on the current
94+ * hardware backend and the model type.
95+ * <p>
96+ * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model
97+ * is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is
98+ * {@code MISTRAL}, the NVIDIA-specific scheduler should not be used.
99+ *
100+ * @param model the model whose type may affect the scheduler decision
101+ * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise
102+ */
103+ public static boolean shouldUseNvidiaScheduler (Model model ) {
104+ TornadoRuntime runtime = TornadoRuntimeProvider .getTornadoRuntime ();
105+ String platformName = runtime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase (Locale .ROOT );
106+
107+ boolean isNvidia = platformName .contains ("nvidia" );
108+ boolean isNotMistral = model .getModelType () != ModelType .MISTRAL ;
109+
110+ boolean result = isNvidia && isNotMistral ;
111+
112+ return result ;
113+ }
114+
92115 /**
93116 * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.
94117 *This method processes the transformer layers in sequence for a particular token position in the context
0 commit comments