
Commit 45bfbbe

Refactor ModelLoader for improved accessibility and cleanup
1 parent 783883c commit 45bfbbe

File tree

1 file changed (+38, -71 lines changed)


src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 38 additions & 71 deletions
@@ -9,7 +9,6 @@
 import com.example.core.model.tensor.GGMLTensorEntry;
 import com.example.core.model.tensor.Q4_0FloatTensor;
 import com.example.core.model.tensor.Q8_0FloatTensor;
-import com.example.core.types.Float16;
 import com.example.core.types.Pair;
 import com.example.inference.engine.impl.Configuration;
 import com.example.inference.engine.impl.Llama;
@@ -22,7 +21,6 @@
 import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
 import java.io.IOException;
-import java.lang.foreign.MemorySegment;
 import java.nio.ByteOrder;
 import java.nio.FloatBuffer;
 import java.nio.channels.FileChannel;
@@ -35,9 +33,6 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
-import static com.example.core.model.tensor.FloatTensor.readByte;
-import static com.example.core.model.tensor.FloatTensor.readShort;
-
 public final class ModelLoader {
     private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";

@@ -105,8 +100,7 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
             GGMLTensorEntry outputWeight) {
         return new Weights(
                 // Load directly to TornadoVM format
-                loadTensorAsFloatArray(tokenEmbeddings),
-                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
@@ -115,7 +109,7 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
                 loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
-                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), createByteArrayFromTensor(outputWeight), outputWeight.ggmlType());
+                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType());
     }
 
     /**
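
The per-layer lambdas in the two hunks above resolve GGUF tensor names of the form blk.<layer>.<name>.weight, one lookup per transformer layer; the second hunk also swaps createByteArrayFromTensor(outputWeight) for loadTensorAsHalfFloatArray(outputWeight), so the output weight now reaches TornadoVM as half floats rather than raw bytes. A minimal sketch of the lookup pattern, with a plain Map<String, float[]> standing in for the real tensorEntries map (the class name and data here are hypothetical, for illustration only):

import java.util.Map;
import java.util.function.IntFunction;

// Hypothetical sketch class, not part of the repository.
public class LayerNameSketch {
    public static void main(String[] args) {
        // Stand-in for tensorEntries: GGUF tensor name -> raw weights.
        Map<String, float[]> entries = Map.of(
                "blk.0.attn_q.weight", new float[] { 0.1f, 0.2f },
                "blk.1.attn_q.weight", new float[] { 0.3f, 0.4f });

        // Same shape as the loader's lambdas: layer index -> tensor entry.
        IntFunction<float[]> attnQ = i -> entries.get("blk." + i + ".attn_q.weight");

        int numberOfLayers = 2;
        float[][] perLayer = new float[numberOfLayers][];
        for (int i = 0; i < numberOfLayers; i++) {
            perLayer[i] = attnQ.apply(i); // one lookup per transformer layer
        }
        System.out.println(perLayer[1][0]); // prints 0.3
    }
}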
@@ -135,23 +129,51 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor
                 FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType());
     }
 
-    private static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+    private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
+                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
+
+        int allTokens = vocabulary.size();
+        int baseTokens = 128000; // assume all tokens after the base ones are special.
+        int reservedSpecialTokens = allTokens - baseTokens;
+        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
+
+        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
+
+        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
+
+        return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
+    }
+
+    public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
+        GGMLType ggmlType = entry.ggmlType();
+        return switch (ggmlType) {
+            // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
+        };
+    }
+
+    public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatArray[] array = new FloatArray[size];
         for (int i = 0; i < size; i++) {
             array[i] = loadTensorAsFloatArray(getTensorEntry.apply(i));
         }
         return array;
     }
 
-    private static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+    public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         HalfFloatArray[] array = new HalfFloatArray[size];
         for (int i = 0; i < size; i++) {
             array[i] = loadTensorAsHalfFloatArray(getTensorEntry.apply(i));
         }
         return array;
     }
 
-    private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
+    public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
         if (tensorEntry.ggmlType() == GGMLType.F32) {
             FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
             return FloatArray.fromFloatBuffer(buffer);
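
The createTokenizer and loadQuantized methods moved up in this hunk. The tokenizer construction parses each GGUF merge line ("left right") into a pair of vocabulary indices and treats every token from index 128000 onward as special. A self-contained sketch of that bookkeeping on a toy vocabulary (all names and data below are illustrative stand-ins, not the project's Vocabulary or Tokenizer types):

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical sketch class, not part of the repository.
public class TokenizerSketch {
    public static void main(String[] args) {
        // Toy vocabulary: a token's id is its position in the array.
        String[] tokens = { "a", "b", "ab", "<|begin_of_text|>", "<|end_of_text|>" };
        Map<String, Integer> index = IntStream.range(0, tokens.length).boxed()
                .collect(Collectors.toMap(i -> tokens[i], i -> i));

        // GGUF merge lines are space-separated token pairs.
        String[] mergeLines = { "a b" };
        List<int[]> merges = Arrays.stream(mergeLines)
                .map(line -> line.split(" "))
                .map(parts -> new int[] { index.get(parts[0]), index.get(parts[1]) })
                .toList();

        // Everything past the base vocabulary counts as special
        // (baseTokens is 128000 in the real loader; 3 in this toy).
        int baseTokens = 3;
        Map<String, Integer> specialTokens = IntStream.range(baseTokens, tokens.length).boxed()
                .collect(Collectors.toMap(i -> tokens[i], i -> i));

        System.out.println(Arrays.toString(merges.get(0))); // [0, 1]
        System.out.println(specialTokens); // e.g. {<|begin_of_text|>=3, <|end_of_text|>=4}
    }
}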
@@ -160,21 +182,20 @@ private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
         }
     }
 
-
-    private static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+    public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatArray[] array = new FloatArray[size];
         for (int i = 0; i < size; i++) {
             array[i] = floatBufferToFloatArray(getTensorEntry.apply(i));
         }
         return array;
     }
 
-    private static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
+    public static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
         FloatTensor tensor = loadQuantized(entry);
         return ByteArray.fromSegment(tensor.asMemorySegment());
     }
 
-    private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
+    public static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
         if (entry.ggmlType() == GGMLType.F32) {
             // For F32, we can directly create FloatArray from memory
             FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
@@ -194,18 +215,10 @@ private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
         }
     }
 
-    private static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
+    public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
         if (entry.ggmlType() == GGMLType.F32) {
-            // For F32, we can directly create FloatArray from memory
-            // FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
-            // FloatArray array = new FloatArray(buffer.remaining());
-            // for (int i = 0; i < buffer.remaining(); i++) {
-            // array.set(i, buffer.get());
-            // }
-            // return array
-            // ;
             System.out.println("Loading F32 tensor as HalfFloatArray");
-            return null;
+            return null;
         } else {
             // For quantized formats, we need to load through FloatTensor
             FloatTensor tensor = loadQuantized(entry);
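
This hunk deletes the commented-out F32 conversion, so loadTensorAsHalfFloatArray still returns null for F32 tensors. A minimal sketch of the missing down-conversion using the JDK's half-precision helpers (Float.floatToFloat16 and Float.float16ToFloat, JDK 20+); a plain short[] stands in for TornadoVM's HalfFloatArray here:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

// Hypothetical sketch class, not part of the repository.
public class HalfFloatSketch {
    // Down-convert little-endian F32 bytes to packed 16-bit floats.
    static short[] toHalfFloats(byte[] f32Bytes) {
        FloatBuffer buffer = ByteBuffer.wrap(f32Bytes)
                .order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
        short[] halves = new short[buffer.remaining()];
        for (int i = 0; i < halves.length; i++) {
            halves[i] = Float.floatToFloat16(buffer.get(i)); // JDK 20+
        }
        return halves;
    }

    public static void main(String[] args) {
        byte[] bytes = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
                .putFloat(1.0f).putFloat(-2.5f).array();
        short[] halves = toHalfFloats(bytes);
        System.out.println(Float.float16ToFloat(halves[1])); // prints -2.5
    }
}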
@@ -218,52 +231,6 @@ private static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry)
         }
     }
 
-    public static float getFloat(int index, int size, MemorySegment memorySegment) {
-        assert 0 <= index && index < size;
-        int blockIndex = index / GGMLType.Q4_0.getBlockSize();
-        int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize();
-        float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset));
-        byte quant;
-        int modIndex = index % GGMLType.Q4_0.getBlockSize();
-        if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) {
-            quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F);
-        } else {
-            quant = (byte) ((readByte(memorySegment, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F);
-        }
-        quant -= 8;
-        return quant * scale;
-    }
-
-    private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
-        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
-        List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" "))
-                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
-
-        int allTokens = vocabulary.size();
-        int baseTokens = 128000; // assume all tokens after the base ones are special.
-        int reservedSpecialTokens = allTokens - baseTokens;
-        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
-
-        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
-
-        Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
-
-        return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens);
-    }
-
-    public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
-        GGMLType ggmlType = entry.ggmlType();
-        // System.out.println("Loading quantized tensor of type " + entry.name());
-        return switch (ggmlType) {
-            // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            // case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
-            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
-        };
-    }
-
     public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatTensor[] array = new FloatTensor[size];
         for (int i = 0; i < size; i++) {
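
The getFloat helper removed above encoded the Q4_0 block layout: each block of 32 weights stores a 2-byte F16 scale followed by 16 bytes of packed 4-bit quants (low nibbles first, then high nibbles), and a weight is recovered as (nibble - 8) * scale. A self-contained sketch of the same arithmetic over a plain byte[] block, replacing the original's MemorySegment reads with array indexing (Float.floatToFloat16 needs JDK 20+):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical sketch class, not part of the repository.
public class Q4_0Sketch {
    static final int BLOCK_SIZE = 32;                // weights per Q4_0 block
    static final int TYPE_SIZE = 2 + BLOCK_SIZE / 2; // f16 scale + packed nibbles = 18 bytes

    // Recover weight `index` (0..31) within one Q4_0 block.
    static float dequantize(byte[] block, int index) {
        float scale = Float.float16ToFloat(
                ByteBuffer.wrap(block).order(ByteOrder.LITTLE_ENDIAN).getShort(0));
        byte quant;
        if (index < BLOCK_SIZE / 2) {
            quant = (byte) (block[2 + index] & 0x0F);                            // low nibble
        } else {
            quant = (byte) ((block[2 + index - BLOCK_SIZE / 2] >>> 4) & 0x0F);   // high nibble
        }
        return (quant - 8) * scale; // quants are offset by 8 around zero
    }

    public static void main(String[] args) {
        byte[] block = new byte[TYPE_SIZE];
        ByteBuffer.wrap(block).order(ByteOrder.LITTLE_ENDIAN)
                .putShort(0, Float.floatToFloat16(0.5f)); // scale = 0.5
        block[2] = (byte) 0xA9; // low nibble 9 -> index 0, high nibble 0xA -> index 16
        System.out.println(dequantize(block, 0));  // (9 - 8) * 0.5 = 0.5
        System.out.println(dequantize(block, 16)); // (10 - 8) * 0.5 = 1.0
    }
}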
