@@ -81,7 +81,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 .persistOnDevice(state.wrapX);
         taskGraphs.add(activationUpdate.snapshot());
 
-
         TaskGraph unifiedLayer = null;
         for (int layerIndex = 0; layerIndex < config.numberOfLayers; layerIndex++) {
             unifiedLayer = new TaskGraph("layer_" + layerIndex);
@@ -135,7 +134,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
         }
 
         TaskGraph lastUnifiedLayer = unifiedLayer;
-
         TaskGraph logits = new TaskGraph("logits")
                 .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
                         state.wrapX
@@ -186,18 +184,21 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
     private TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
         switch (weights.weightType) {
             case Q8_0:
-                logits.task("projection", TransformerComputeKernels::matmulTornadoQ8Optimized, context, weights.wclsByteArray, state.wrapX, state.wrapLogits, config.dim);
+                logits.task("projection", TransformerComputeKernels::matmulTornadoQ8Optimized, //
+                        context, weights.wclsByteArray, state.wrapX, //
+                        state.wrapLogits, config.dim); //
                 break;
             case Q4_0:
-                logits.task("projection", TransformerComputeKernels::matmulTornadoQ4Optimized, context, weights.wclsByteArray, state.wrapX, state.wrapLogits, config.dim);
+                logits.task("projection", TransformerComputeKernels::matmulTornadoQ4Optimized, //
+                        context, weights.wclsByteArray, state.wrapX, //
+                        state.wrapLogits, config.dim); //
                 break;
             default:
                 throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.weightType + ". Only Q8_0 and Q4_0 are supported.");
         }
         return logits;
     }
 
-    // @formatter:off
     /**
      * Configures data transfer operations for a specific layer in the neural network task graph.
      *
@@ -218,29 +219,21 @@ private TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex)
         // First layer: Transfer initial data to device (one-time transfer)
         if (layerIndex == 0) {
             // Transfer all attention-related data: query, key, value matrices and their caches
-            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
-                    context, state.wrapXb, state.wrapXb2,
-                    state.wrapQ, state.wrapK, state.wrapV,
-                    state.wrapKeyCache, state.wrapValueCache,
-                    state.wrapAtt, state.wrapHb);
+            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
+            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
+                    context, state.wrapXb, state.wrapXb2, //
+                    state.wrapQ, state.wrapK, state.wrapV, //
+                    state.wrapKeyCache, state.wrapValueCache, //
+                    state.wrapAtt, state.wrapHb); //
         } else {
             // Subsequent layers: Consume data already on device from previous layer
-            unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2,
-                    state.wrapQ, state.wrapK, state.wrapV,
-                    state.wrapKeyCache, state.wrapValueCache,
-                    state.wrapAtt, state.wrapHb
+            unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
+                    state.wrapQ, state.wrapK, state.wrapV, //
+                    state.wrapKeyCache, state.wrapValueCache, //
+                    state.wrapAtt, state.wrapHb, //
+                    state.positionHolder //
             );
         }
-
-        // First layer: Transfer position and temp data (transferred every execution)
-        if ((layerIndex) == 0) {
-            // Transfer data that changes with each execution (position, temp buffers)
-            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN);
-        } else {
-            // Subsequent layers: Only consume position data from device
-            unifiedLayer.consumeFromDevice(state.positionHolder);
-        }
-        // @formatter:on
         return unifiedLayer;
     }
 
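Net effect of the last hunk: the two layer-0 conditionals in configureLayerDataTransfers collapse into one. The EVERY_EXECUTION transfer of positionHolder, temp, and tempFFN now sits in the layerIndex == 0 branch next to the one-time FIRST_EXECUTION transfer, and later layers consume positionHolder from the device together with the other per-layer buffers.

For orientation, a minimal, hypothetical sketch of how a caller could execute the task graphs this planner produces. The Tuple2 accessors getFirst()/getSecond() are placeholders for whatever the project's tuple type actually exposes; the TornadoVM calls (TornadoExecutionPlan, withGridScheduler, execute) are the standard uk.ac.manchester.tornado.api entry points.

    import java.util.List;

    import uk.ac.manchester.tornado.api.GridScheduler;
    import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
    import uk.ac.manchester.tornado.api.TornadoExecutionPlan;

    // Hypothetical helper: runs the snapshotted graphs (activation update,
    // one graph per transformer layer, then logits) in submission order.
    static void executeForwardPlan(Tuple2<List<ImmutableTaskGraph>, GridScheduler> plan) {
        List<ImmutableTaskGraph> graphs = plan.getFirst();   // placeholder accessor
        GridScheduler scheduler = plan.getSecond();          // placeholder accessor

        // TornadoExecutionPlan takes the immutable graphs as varargs; attaching
        // the GridScheduler applies each task's worker-grid configuration.
        TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(graphs.toArray(new ImmutableTaskGraph[0]));
        executionPlan.withGridScheduler(scheduler).execute();
    }

Because wrapX is persisted on the device and the per-layer buffers are moved under FIRST_EXECUTION, only the EVERY_EXECUTION buffers (positionHolder, temp, tempFFN) should cross the host-device boundary on forward passes after the first.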