44import org .beehive .gpullama3 .core .model .tensor .FloatTensor ;
55import org .beehive .gpullama3 .inference .state .Phi3State ;
66import org .beehive .gpullama3 .inference .state .State ;
7- import org .beehive .gpullama3 .inference .weights .standard .Qwen2StandardWeights ;
87import org .beehive .gpullama3 .inference .weights .standard .Phi3StandardWeights ;
8+ import org .beehive .gpullama3 .inference .weights .standard .Qwen2StandardWeights ;
99import org .beehive .gpullama3 .inference .weights .standard .Qwen3StandardWeights ;
1010import org .beehive .gpullama3 .inference .weights .standard .StandardWeights ;
1111import org .beehive .gpullama3 .inference .weights .tornado .TornadoWeights ;
1212import org .beehive .gpullama3 .model .Configuration ;
1313import org .beehive .gpullama3 .model .Model ;
14- import org .beehive .gpullama3 .model .qwen2 .Qwen2Configuration ;
1514import org .beehive .gpullama3 .model .phi3 .Phi3Configuration ;
15+ import org .beehive .gpullama3 .model .qwen2 .Qwen2Configuration ;
1616import org .beehive .gpullama3 .model .qwen3 .Qwen3Configuration ;
1717import org .beehive .gpullama3 .tornadovm .TornadoVMMasterPlan ;
18-
1918import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
2019
2120import java .lang .foreign .MemorySegment ;
@@ -218,9 +217,9 @@ public static FloatTensor forwardJavaQwen2(Model model, State state, int token,
218217 for (int vi = 0 ; vi < rotn ; vi ++) {
219218 FloatTensor vec = (vi == 0 ) ? state .q : state .k ; // the vector to rotate (query or key)
220219 float v0 = vec .getFloat (poffset + ic );
221- float v1 = vec .getFloat (poffset + ic + headSize / 2 );
220+ float v1 = vec .getFloat (poffset + ic + headSize / 2 );
222221 vec .setFloat (poffset + ic , v0 * fcr - v1 * fci );
223- vec .setFloat (poffset + ic + headSize / 2 , v0 * fci + v1 * fcr );
222+ vec .setFloat (poffset + ic + headSize / 2 , v0 * fci + v1 * fcr );
224223 }
225224 }
226225 }
@@ -231,7 +230,7 @@ public static FloatTensor forwardJavaQwen2(Model model, State state, int token,
231230 state .v .copyTo (0 , state .valueCache [curLayer ], position * kvDim , kvDim );
232231
233232 // multihead attention. iterate over all heads
234- Parallel .parallelFor (0 , config .numberOfHeads (), h -> {
233+ Parallel .parallelFor (0 , config .numberOfHeads (), h -> {
235234 // get the query vector for this head
236235 // float* q = s.q + h * headSize;
237236 int qOffset = h * headSize ;
@@ -584,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
584583 final Configuration configuration = model .configuration ();
585584 final TornadoWeights weights = (TornadoWeights ) model .weights ();
586585
587- MemorySegment .copy (weights .tokenEmbeddingTable .getSegment (), token * configuration .dim () * Float .BYTES , state .wrapX .getSegment (), 0 , configuration .dim () * Float .BYTES );
586+ MemorySegment .copy (weights .tokenEmbeddingTable .getSegment (), ( long ) token * configuration .dim () * Float .BYTES , state .wrapX .getSegment (), 0 , configuration .dim () * Float .BYTES );
588587
589588 return tornadoVMMasterPlan .tornadoVMForwardExecuteLayered (position );
590589 }
0 commit comments