Skip to content

Commit 4bed4e2

Browse files
committed
Switch to HalfFloatArray for model weights
1 parent bb1676b commit 4bed4e2

File tree

3 files changed

+110
-20
lines changed

3 files changed

+110
-20
lines changed

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import com.example.inference.operation.RoPE;
1717
import com.example.tokenizer.impl.Tokenizer;
1818
import com.example.tokenizer.vocabulary.Vocabulary;
19+
import uk.ac.manchester.tornado.api.types.HalfFloat;
1920
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
2021
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
22+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
2123

2224
import java.io.IOException;
2325
import java.lang.foreign.MemorySegment;
@@ -103,15 +105,16 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
103105
GGMLTensorEntry outputWeight) {
104106
return new Weights(
105107
// Load directly to TornadoVM format
106-
loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
107-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
108-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
109-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
110-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
108+
loadTensorAsFloatArray(tokenEmbeddings),
109+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
110+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
111+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
112+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
113+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
111114
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
112-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
113-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
114-
loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
115+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
116+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
117+
loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
115118
FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), createByteArrayFromTensor(outputWeight), outputWeight.ggmlType());
116119
}
117120

@@ -140,6 +143,14 @@ private static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTens
140143
return array;
141144
}
142145

146+
private static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
147+
HalfFloatArray[] array = new HalfFloatArray[size];
148+
for (int i = 0; i < size; i++) {
149+
array[i] = loadTensorAsHalfFloatArray(getTensorEntry.apply(i));
150+
}
151+
return array;
152+
}
153+
143154
private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
144155
if (tensorEntry.ggmlType() == GGMLType.F32) {
145156
FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
@@ -149,6 +160,7 @@ private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
149160
}
150161
}
151162

163+
152164
private static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
153165
FloatArray[] array = new FloatArray[size];
154166
for (int i = 0; i < size; i++) {
@@ -182,6 +194,30 @@ private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
182194
}
183195
}
184196

197+
private static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
198+
if (entry.ggmlType() == GGMLType.F32) {
199+
// For F32, we can directly create FloatArray from memory
200+
// FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
201+
// FloatArray array = new FloatArray(buffer.remaining());
202+
// for (int i = 0; i < buffer.remaining(); i++) {
203+
// array.set(i, buffer.get());
204+
// }
205+
// return array
206+
// ;
207+
System.out.println("Loading F32 tensor as HalfFloatArray");
208+
return null;
209+
} else {
210+
// For quantized formats, we need to load through FloatTensor
211+
FloatTensor tensor = loadQuantized(entry);
212+
HalfFloatArray array = new HalfFloatArray(tensor.size());
213+
for (int i = 0; i < tensor.size(); i++) {
214+
HalfFloat x = new HalfFloat(tensor.getFloat(i));
215+
array.set(i, x);
216+
}
217+
return array;
218+
}
219+
}
220+
185221
public static float getFloat(int index, int size, MemorySegment memorySegment) {
186222
assert 0 <= index && index < size;
187223
int blockIndex = index / GGMLType.Q4_0.getBlockSize();

src/main/java/com/example/loader/weights/Weights.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import com.example.LlamaApp;
44
import com.example.core.model.GGMLType;
55
import com.example.core.model.tensor.FloatTensor;
6+
import uk.ac.manchester.tornado.api.types.HalfFloat;
67
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
78
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
9+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
810

911
import java.nio.FloatBuffer;
1012

@@ -34,15 +36,15 @@ public class Weights {
3436
public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)
3537
// // Layered Data structures
3638
public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights
37-
public FloatArray[] wqLayered; // (layer, n_heads * head_size)
38-
public FloatArray[] wkLayered; // (layer, n_kv_heads, head_size)
39-
public FloatArray[] wvLayered; // (layer, n_kv_heads * head_size)
40-
public FloatArray[] woLayered; // (layer, n_heads * head_size, dim)
39+
public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size)
40+
public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size)
41+
public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size)
42+
public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim)
4143
public FloatArray[] rms_ffn_weightLayered; // (layer, dim)
42-
public FloatArray[] w1Layered; // (layer, hidden_dim, dim)
43-
public FloatArray[] w2Layered; // (layer, dim, hidden_dim)
44+
public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim)
45+
public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim)
4446
//
45-
public FloatArray[] w3Layered; // (layer, hidden_dim, dim)
47+
public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim)
4648
public FloatArray rms_final_weight_as_floatArray;
4749
public FloatArray tokenEmbeddingTable; // (vocab_size, dim)
4850
public FloatArray freq_cis_realFlat; // (seq_len, head_size/2)
@@ -115,9 +117,10 @@ public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight,
115117
/**
116118
* Constructor for TornadoVM mode
117119
*/
118-
public Weights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered,
119-
FloatArray[] wqLayered, FloatArray[] wkLayered, FloatArray[] wvLayered, FloatArray[] woLayered,
120-
FloatArray[] rms_ffn_weightLayered, FloatArray[] w1Layered, FloatArray[] w2Layered, FloatArray[] w3Layered,
120+
public Weights(FloatArray tokenEmbeddingTable,
121+
FloatArray[] rms_att_weightLayered,
122+
HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered,
123+
FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered,
121124
FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat,
122125
ByteArray wclsByteArray, GGMLType weightType) {
123126
// Standard format (null when using TornadoVM)

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import uk.ac.manchester.tornado.api.KernelContext;
44
import uk.ac.manchester.tornado.api.annotations.Parallel;
55
import uk.ac.manchester.tornado.api.math.TornadoMath;
6+
import uk.ac.manchester.tornado.api.types.HalfFloat;
67
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
8+
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
79
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
810

911
public class TransformerComputeKernelsLayered {
@@ -454,6 +456,24 @@ public static void matrixVectorGeneric(KernelContext context, FloatArray x, Floa
454456
}
455457
}
456458

459+
public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
460+
// One row per workgroup (not per thread)
461+
int rowId = context.groupIdx;
462+
int localId = context.localIdx;
463+
int localSize = localWorkGroupSize;
464+
465+
// Early exit if this workgroup is beyond our output dimension
466+
if (rowId >= d) {
467+
return;
468+
}
469+
float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n, d);
470+
471+
// Thread 0 in each workgroup writes the final result
472+
if (localId == 0) {
473+
hb.set(rowId, sum);
474+
}
475+
}
476+
457477
/**
458478
* Matrix-vector multiplication with residual connection.
459479
* Combines regular matrix multiplication with addition of existing values.
@@ -468,7 +488,7 @@ public static void matrixVectorGeneric(KernelContext context, FloatArray x, Floa
468488
* @param d Output dimension
469489
* @param localWorkGroupSize Work group size
470490
*/
471-
public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
491+
public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
472492
// One row per workgroup (not per thread)
473493
int rowId = context.groupIdx;
474494
int localId = context.localIdx;
@@ -504,7 +524,7 @@ public static void matrixVectorGenericWithResidual(KernelContext context, FloatA
504524
* @param d Hidden dimension
505525
* @param localWorkGroupSize Work group size
506526
*/
507-
public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, FloatArray w1, FloatArray w3, int n, int d, int localWorkGroupSize) {
527+
public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) {
508528
// One row per workgroup (not per thread)
509529
int rowId = context.groupIdx;
510530
int localId = context.localIdx;
@@ -597,4 +617,35 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
597617

598618
return localSum[0];
599619
}
620+
/**
 * Cooperative row dot product for FP16 weights: computes
 * {@code dot(w[rowId, :], x)} where {@code rowId = context.groupIdx}.
 * All threads of the workgroup must call this method; the reduced sum is
 * returned in local-memory slot 0 (meaningful for thread 0 of the group).
 *
 * @param context   TornadoVM kernel context (supplies groupIdx/localIdx)
 * @param localSize work group size, also the local-memory reduction width
 * @param x         input vector (length n)
 * @param w         FP16 weight matrix, row-major, rows of length n
 * @param n         row length / input dimension
 * @param d         output dimension; unused here — presumably kept for
 *                  signature parity with the FloatArray overload (confirm)
 * @return the reduced dot product (valid after the final barrier)
 */
public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n, int d) {
    int rowId = context.groupIdx;
    int localId = context.localIdx;

    // Allocate local (workgroup-shared) memory for the reduction tree
    float[] localSum = context.allocateFloatLocalArray(localSize);

    // Start of this workgroup's row in the row-major weight matrix
    int rowOffset = rowId * n;

    // Each thread accumulates a strided partial dot product: thread `localId`
    // handles columns localId, localId+localSize, localId+2*localSize, ...
    float partialSum = 0.0f;
    for (int j = localId; j < n; j += localSize) {
        int matrixIdx = rowOffset + j;
        // FP16 weight is widened to float32 before the multiply-accumulate
        partialSum += w.get(matrixIdx).getFloat32() * x.get(j);
    }

    // Store partial sum in local memory, then synchronize the workgroup
    localSum[localId] = partialSum;
    context.localBarrier();

    // Parallel tree reduction within the workgroup, halving active threads
    // each step. NOTE(review): assumes localSize is a power of two — odd
    // sizes would drop the top element; confirm launch configuration.
    for (int stride = localSize / 2; stride > 0; stride >>= 1) {
        if (localId < stride) {
            localSum[localId] += localSum[localId + stride];
        }
        context.localBarrier();
    }

    return localSum[0];
}
600651
}

0 commit comments

Comments
 (0)