Skip to content

Commit 788c11e

Browse files
committed
Refactor: adjust spacing for readability, reorder imports, and improve type casting in InferenceCore.
1 parent 91d7d9e commit 788c11e

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@
44
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
55
import org.beehive.gpullama3.inference.state.Phi3State;
66
import org.beehive.gpullama3.inference.state.State;
7-
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
87
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
8+
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
99
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
1010
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
1111
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
1212
import org.beehive.gpullama3.model.Configuration;
1313
import org.beehive.gpullama3.model.Model;
14-
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
1514
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
15+
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
1616
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
1717
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
18-
1918
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
2019

2120
import java.lang.foreign.MemorySegment;
@@ -218,9 +217,9 @@ public static FloatTensor forwardJavaQwen2(Model model, State state, int token,
218217
for (int vi = 0; vi < rotn; vi++) {
219218
FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
220219
float v0 = vec.getFloat(poffset + ic);
221-
float v1 = vec.getFloat(poffset + ic + headSize/2);
220+
float v1 = vec.getFloat(poffset + ic + headSize / 2);
222221
vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
223-
vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr);
222+
vec.setFloat(poffset + ic + headSize / 2, v0 * fci + v1 * fcr);
224223
}
225224
}
226225
}
@@ -231,7 +230,7 @@ public static FloatTensor forwardJavaQwen2(Model model, State state, int token,
231230
state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim);
232231

233232
// multihead attention. iterate over all heads
234-
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
233+
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
235234
// get the query vector for this head
236235
// float* q = s.q + h * headSize;
237236
int qOffset = h * headSize;
@@ -584,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
584583
final Configuration configuration = model.configuration();
585584
final TornadoWeights weights = (TornadoWeights) model.weights();
586585

587-
MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
586+
MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
588587

589588
return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
590589
}

0 commit comments

Comments
 (0)