Skip to content

Commit 12aa536

Browse files
committed
Refactor task graph configuration for clarity and maintainability
1 parent 9178f51 commit 12aa536

File tree

1 file changed

+17
-24
lines changed

1 file changed

+17
-24
lines changed

src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java

Lines changed: 17 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
8181
.persistOnDevice(state.wrapX);
8282
taskGraphs.add(activationUpdate.snapshot());
8383

84-
8584
TaskGraph unifiedLayer = null;
8685
for (int layerIndex =0; layerIndex < config.numberOfLayers; layerIndex++) {
8786
unifiedLayer = new TaskGraph("layer_" + layerIndex);
@@ -135,7 +134,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
135134
}
136135

137136
TaskGraph lastUnifiedLayer = unifiedLayer;
138-
139137
TaskGraph logits = new TaskGraph("logits")
140138
.consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
141139
state.wrapX
@@ -186,18 +184,21 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
186184
private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
187185
switch (weights.weightType) {
188186
case Q8_0:
189-
logits.task("projection", TransformerComputeKernels::matmulTornadoQ8Optimized, context, weights.wclsByteArray, state.wrapX, state.wrapLogits, config.dim);
187+
logits.task("projection", TransformerComputeKernels::matmulTornadoQ8Optimized, //
188+
context, weights.wclsByteArray, state.wrapX, //
189+
state.wrapLogits, config.dim); //
190190
break;
191191
case Q4_0:
192-
logits.task("projection", TransformerComputeKernels::matmulTornadoQ4Optimized, context, weights.wclsByteArray, state.wrapX, state.wrapLogits, config.dim);
192+
logits.task("projection", TransformerComputeKernels::matmulTornadoQ4Optimized, //
193+
context, weights.wclsByteArray, state.wrapX, //
194+
state.wrapLogits, config.dim); //
193195
break;
194196
default:
195197
throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported.");
196198
}
197199
return logits;
198200
}
199201

200-
// @formatter:off
201202
/**
202203
* Configures data transfer operations for a specific layer in the neural network task graph.
203204
*
@@ -218,29 +219,21 @@ private TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerI
218219
// First layer: Transfer initial data to device (one-time transfer)
219220
if (layerIndex == 0) {
220221
// Transfer all attention-related data: query, key, value matrices and their caches
221-
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
222-
context, state.wrapXb, state.wrapXb2,
223-
state.wrapQ, state.wrapK, state.wrapV,
224-
state.wrapKeyCache, state.wrapValueCache,
225-
state.wrapAtt, state.wrapHb);
222+
unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
223+
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
224+
context, state.wrapXb, state.wrapXb2, //
225+
state.wrapQ, state.wrapK, state.wrapV, //
226+
state.wrapKeyCache, state.wrapValueCache, //
227+
state.wrapAtt, state.wrapHb); //
226228
} else {
227229
// Subsequent layers: Consume data already on device from previous layer
228-
unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2,
229-
state.wrapQ, state.wrapK, state.wrapV,
230-
state.wrapKeyCache, state.wrapValueCache,
231-
state.wrapAtt, state.wrapHb
230+
unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
231+
state.wrapQ, state.wrapK, state.wrapV, //
232+
state.wrapKeyCache, state.wrapValueCache, //
233+
state.wrapAtt, state.wrapHb, //
234+
state.positionHolder //
232235
);
233236
}
234-
235-
// First layer: Transfer position and temp data (transferred every execution)
236-
if ((layerIndex) == 0) {
237-
// Transfer data that changes with each execution (position, temp buffers)
238-
unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN);
239-
} else {
240-
// Subsequent layers: Only consume position data from device
241-
unifiedLayer.consumeFromDevice(state.positionHolder);
242-
}
243-
// @formatter:on
244237
return unifiedLayer;
245238
}
246239

0 commit comments

Comments (0)