
Commit ac7a0ae
Remove BF16FloatTensor and implement TornadoVM support for weights
The BF16FloatTensor class was removed entirely, simplifying the codebase. TornadoVM support was added for loading model weights, improving performance when running with TornadoVM. New helper methods handle the required FloatArray and ByteArray conversions.
1 parent 12aa536 commit ac7a0ae
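The core conversion the new loading path performs, as a minimal standalone sketch (FloatArray.fromArray and set appear in the diff below; the class name and values here are hypothetical): TornadoVM kernels consume its off-heap array types rather than plain Java arrays, so weights are converted once at load time.

import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

// Minimal sketch: copy host-side float[] data into TornadoVM's off-heap
// FloatArray, the representation the new weight-loading path produces.
public class TornadoArraySketch {
    public static void main(String[] args) {
        float[] hostWeights = {0.1f, 0.2f, 0.3f};   // hypothetical values
        FloatArray deviceReady = FloatArray.fromArray(hostWeights);
        deviceReady.set(0, 0.5f);                   // mutates the off-heap copy
        System.out.println(deviceReady.get(0));     // prints 0.5
    }
}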

File tree

3 files changed: +90 -128 lines changed

src/main/java/com/example/core/model/tensor/BF16FloatTensor.java

Lines changed: 0 additions & 104 deletions
This file was deleted.
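For context on what was dropped: bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so decoding is one shift. A minimal sketch of that decoding (an illustration, not the deleted class's code):

// Sketch: widening bfloat16 to float32 is a 16-bit left shift; the low
// mantissa bits of the resulting float32 are zero.
public class Bf16Sketch {
    static float bf16ToFloat(short bits) {
        return Float.intBitsToFloat((bits & 0xFFFF) << 16);
    }

    public static void main(String[] args) {
        System.out.println(bf16ToFloat((short) 0x3F80)); // prints 1.0
        System.out.println(bf16ToFloat((short) 0xC000)); // prints -2.0
    }
}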

src/main/java/com/example/core/model/tensor/FloatTensor.java

Lines changed: 2 additions & 2 deletions
@@ -54,12 +54,12 @@ public abstract class FloatTensor {
         }
     }
 
-    static short readShort(MemorySegment memorySegment, long offset) {
+    public static short readShort(MemorySegment memorySegment, long offset) {
        // The MemorySegment.get* methods should be used instead.
        return UNSAFE.getShort(memorySegment.address() + offset);
    }
 
-    static byte readByte(MemorySegment memorySegment, long offset) {
+    public static byte readByte(MemorySegment memorySegment, long offset) {
        // The MemorySegment.get* methods should be used instead.
        return UNSAFE.getByte(memorySegment.address() + offset);
    }
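Widening these readers to public lets ModelLoader import them statically (see the ModelLoader diff below). A self-contained sketch of the equivalent read using the MemorySegment.get accessors that the comment itself recommends (JDK 21+ assumed; all names here are illustrative):

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;

// Sketch: read a little-endian short out of a MemorySegment, as readShort
// does for the fp16 scale at the head of a quantized block.
public class ReadShortSketch {
    static final ValueLayout.OfShort LE_SHORT =
            ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);

    static short readShort(MemorySegment segment, long offset) {
        return segment.get(LE_SHORT, offset);
    }

    public static void main(String[] args) {
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment segment = arena.allocate(2);
            segment.set(LE_SHORT, 0, (short) 0x3C00);  // fp16 bit pattern for 1.0
            System.out.println(Float.float16ToFloat(readShort(segment, 0))); // prints 1.0
        }
    }
}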

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 88 additions & 22 deletions
@@ -9,15 +9,18 @@
 import com.example.core.model.tensor.GGMLTensorEntry;
 import com.example.core.model.tensor.Q4_0FloatTensor;
 import com.example.core.model.tensor.Q8_0FloatTensor;
+import com.example.core.types.Float16;
 import com.example.core.types.Pair;
 import com.example.inference.engine.impl.Configuration;
 import com.example.inference.engine.impl.Llama;
 import com.example.inference.operation.RoPE;
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tokenizer.vocabulary.Vocabulary;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.io.IOException;
+import java.lang.foreign.MemorySegment;
 import java.nio.ByteOrder;
 import java.nio.FloatBuffer;
 import java.nio.channels.FileChannel;
@@ -30,6 +33,9 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static com.example.core.model.tensor.FloatTensor.readByte;
+import static com.example.core.model.tensor.FloatTensor.readShort;
+
 public final class ModelLoader {
     private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
 

@@ -83,39 +89,83 @@ public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Co
         );
 
         GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
+        GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
 
-        return createRegularWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings);
+        if (LlamaApp.USE_TORNADOVM) {
+            System.out.println("Loading weights in TornadoVM format");
+            return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+        } else {
+            return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+        }
     }
 
-    private static Weights createRegularWeights(Map<String, GGMLTensorEntry> tensorEntries,
-                                                Configuration config,
-                                                Pair<float[], float[]> ropeFreqs,
-                                                GGMLTensorEntry tokenEmbeddings) {
-        float[] ropeFreqsReal = ropeFreqs.first();
-        float[] ropeFreqsImag = ropeFreqs.second();
-        return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+    private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        return new Weights(
+                // Load directly to TornadoVM format
+                loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayAsFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), createByteArrayFromTensor(outputWeight), outputWeight.ggmlType());
+    }
+
+    /**
+     * Creates weights in standard format only
+     */
+    private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
             loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
             loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
             loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
             loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
             loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-            loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
-            loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
-            loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
-            // final layer normalization that's applied after all transformer blocks but before the final projection to vocabulary logits.
-            toFloatBuffer(tensorEntries.get("output_norm.weight")),
-            FloatBuffer.wrap(ropeFreqsReal), //
-            FloatBuffer.wrap(ropeFreqsImag), //
-            // If "output.weight" is not present, then the embedding weights are tied/shared with the decoder.
-            // This is commonly referred to as "tie word embeddings".
-            loadQuantized(tensorEntries.getOrDefault("output.weight", tokenEmbeddings)),
-            tensorEntries.getOrDefault("output.weight", tokenEmbeddings).ggmlType());
+            loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+            loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+            loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), toFloatBuffer(tensorEntries.get("output_norm.weight")),
+            FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
+    }
+
+    private static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        FloatArray[] array = new FloatArray[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = loadTensorAsFloatArray(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
+    private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
+        if (tensorEntry.ggmlType() == GGMLType.F32) {
+            FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
+            return FloatArray.fromFloatBuffer(buffer);
+        } else {
+            throw new UnsupportedOperationException("Conversion to FloatArray from " + tensorEntry.ggmlType());
+        }
     }
+
+    private static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        FloatArray[] array = new FloatArray[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = floatBufferToFloatArray(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
+    private static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
+        FloatTensor tensor = loadQuantized(entry);
+        return ByteArray.fromSegment(tensor.asMemorySegment());
+    }
+
     private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
         if (entry.ggmlType() == GGMLType.F32) {
             // For F32, we can directly create FloatArray from memory
-            FloatBuffer buffer = entry.memorySegment().asByteBuffer()
-                    .order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
+            FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
             FloatArray array = new FloatArray(buffer.remaining());
             for (int i = 0; i < buffer.remaining(); i++) {
                 array.set(i, buffer.get());
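The F32 branch above (shared by floatBufferToFloatArray and loadTensorAsFloatArray) views the tensor's raw MemorySegment as a little-endian FloatBuffer before filling a FloatArray. A standalone sketch of that view, with the TornadoVM type left out so it runs on a plain JDK (21+; names are illustrative):

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

// Sketch: raw little-endian bytes -> FloatBuffer, the intermediate step the
// new helpers take before copying element-wise into a TornadoVM FloatArray.
public class F32ViewSketch {
    public static void main(String[] args) {
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment segment = arena.allocate(2L * Float.BYTES);
            segment.set(ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN), 0, 1.5f);
            FloatBuffer buffer = segment.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
            float[] copy = new float[buffer.remaining()];
            for (int i = 0; i < copy.length; i++) {
                copy[i] = buffer.get();  // the helpers call FloatArray.set(i, ...) here
            }
            System.out.println(copy[0]); // prints 1.5
        }
    }
}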
@@ -132,6 +182,22 @@ private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
         }
     }
 
+    public static float getFloat(int index, int size, MemorySegment memorySegment) {
+        assert 0 <= index && index < size;
+        int blockIndex = index / GGMLType.Q4_0.getBlockSize();
+        int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize();
+        float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
+        byte quant;
+        int modIndex = index % GGMLType.Q4_0.getBlockSize();
+        if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) {
+            quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F);
+        } else {
+            quant = (byte) ((readByte(memorySegment, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F);
+        }
+        quant -= 8;
+        return quant * scale;
+    }
+
     private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
         String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
         List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
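getFloat above decodes GGUF's Q4_0 layout: each block packs 32 weights as one fp16 scale followed by 16 bytes of 4-bit quants (low nibbles hold values 0-15 of the block, high nibbles values 16-31), each biased by 8. A plain-Java illustration with the block constants hard-coded (in GGUF, Q4_0's block size is 32 values and its type size 18 bytes; treat these numbers as stated assumptions standing in for the project's GGMLType accessors):

// Standalone illustration (not the commit's code) of the Q4_0 layout that
// getFloat decodes: an 18-byte block = fp16 scale + 32 packed 4-bit quants.
public class Q40Sketch {
    static final int BLOCK_SIZE = 32; // values per block (GGMLType.Q4_0.getBlockSize())
    static final int TYPE_SIZE = 18;  // bytes per block: 2 (fp16 scale) + 16 (nibbles)

    static float dequantize(byte[] data, int index) {
        int blockOffset = (index / BLOCK_SIZE) * TYPE_SIZE;
        // Little-endian fp16 scale at the head of the block.
        short rawScale = (short) ((data[blockOffset] & 0xFF) | (data[blockOffset + 1] << 8));
        float scale = Float.float16ToFloat(rawScale);
        int mod = index % BLOCK_SIZE;
        int packed = data[blockOffset + 2 + (mod % (BLOCK_SIZE / 2))] & 0xFF;
        // First half of the block lives in low nibbles, second half in high nibbles.
        int quant = (mod < BLOCK_SIZE / 2) ? (packed & 0x0F) : (packed >>> 4);
        return (quant - 8) * scale; // stored values are biased by 8
    }

    public static void main(String[] args) {
        byte[] block = new byte[TYPE_SIZE];
        block[0] = 0x00;
        block[1] = 0x3C;             // fp16 scale 1.0, little-endian
        block[2] = (byte) 0xA9;      // low nibble 9 -> index 0; high nibble 0xA -> index 16
        System.out.println(dequantize(block, 0));  // (9 - 8) * 1.0 = 1.0
        System.out.println(dequantize(block, 16)); // (10 - 8) * 1.0 = 2.0
    }
}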
@@ -151,7 +217,7 @@ private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabular
 
     public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
         GGMLType ggmlType = entry.ggmlType();
-        // System.out.println("Tensor type: " + ggmlType + " " + entry.name() + " " + entry.shape().length);
+        // System.out.println("Loading quantized tensor of type " + entry.name());
         return switch (ggmlType) {
             // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
             case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
