|
13 | 13 | import com.example.model.Configuration; |
14 | 14 | import com.example.model.Model; |
15 | 15 | import com.example.model.ModelType; |
16 | | -import com.example.model.llama.LlamaConfiguration; |
17 | | -import com.example.model.llama.Llama; |
18 | | -import com.example.model.mistral.Mistral; |
19 | | -import com.example.model.mistral.MistralConfiguration; |
20 | 16 | import com.example.inference.operation.RoPE; |
21 | | -import com.example.tokenizer.impl.LlamaTokenizer; |
22 | | -import com.example.tokenizer.impl.MistralTokenizer; |
23 | | -import com.example.tokenizer.impl.Tokenizer; |
24 | | -import com.example.tokenizer.vocabulary.Vocabulary; |
25 | 17 | import uk.ac.manchester.tornado.api.types.HalfFloat; |
26 | 18 | import uk.ac.manchester.tornado.api.types.arrays.ByteArray; |
27 | 19 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
|
33 | 25 | import java.nio.channels.FileChannel; |
34 | 26 | import java.nio.file.Path; |
35 | 27 | import java.nio.file.StandardOpenOption; |
36 | | -import java.util.Arrays; |
37 | | -import java.util.List; |
38 | 28 | import java.util.Map; |
39 | 29 | import java.util.function.IntFunction; |
40 | | -import java.util.stream.Collectors; |
41 | | -import java.util.stream.IntStream; |
42 | 30 |
|
43 | 31 | public final class ModelLoader { |
44 | 32 | private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; |
45 | 33 | private static final String TOKENIZER_MISTRAL_MODEL = "llama"; |
46 | 34 |
|
47 | | - private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; |
48 | | - private static final String MISTRAL_PATTERN = "\\S+|\\s+"; |
49 | | - |
50 | 35 | private static ModelType detectModelType(Map<String, Object> metadata) { |
51 | 36 | String name = (String) metadata.get("general.name"); |
52 | 37 | String tokenizerModel = (String) metadata.get("tokenizer.ggml.model"); |
@@ -232,37 +217,6 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor |
232 | 217 | FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType()); |
233 | 218 | } |
234 | 219 |
|
235 | | - private static Tokenizer createLlama3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) { |
236 | | - String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); |
237 | | - List<Pair<Integer, Integer>> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) |
238 | | - .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); |
239 | | - |
240 | | - int allTokens = vocabulary.size(); |
241 | | - int baseTokens = 128000; // assume all tokens after the base ones are special. |
242 | | - int reservedSpecialTokens = allTokens - baseTokens; |
243 | | - List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); |
244 | | - |
245 | | - assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); |
246 | | - |
247 | | - Map<String, Integer> specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); |
248 | | - |
249 | | - return new LlamaTokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens); |
250 | | - |
251 | | - } |
252 | | - |
253 | | - private static Tokenizer createMistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) { |
254 | | - int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); |
255 | | - List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList(); |
256 | | - Map<String, Integer> specialTokens = |
257 | | - IntStream.range(0, specialTokensList.size()) |
258 | | - .boxed() |
259 | | - .collect(Collectors.toMap( |
260 | | - t -> vocabulary.get(t), |
261 | | - t -> t) |
262 | | - ); |
263 | | - return new MistralTokenizer(vocabulary, null, specialTokens, tokenTypes); |
264 | | - } |
265 | | - |
266 | 220 | public static FloatTensor loadQuantized(GGMLTensorEntry entry) { |
267 | 221 | GGMLType ggmlType = entry.ggmlType(); |
268 | 222 | return switch (ggmlType) { |
|
0 commit comments