|
1 | 1 | package com.example.loader.weights; |
2 | 2 |
|
3 | 3 | import com.example.LlamaApp; |
4 | | -import com.example.auxiliary.Timer; |
5 | 4 | import com.example.core.model.GGMLType; |
6 | 5 | import com.example.core.model.GGUF; |
7 | 6 | import com.example.core.model.tensor.F16FloatTensor; |
@@ -70,89 +69,10 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeig |
70 | 69 | // initial load of metadata from gguf file |
71 | 70 | GGUF gguf = GGUF.loadModel(ggufPath); |
72 | 71 | FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); |
73 | | - |
74 | 72 | // detect model type |
75 | 73 | ModelType modelType = detectModelType(gguf.getMetadata()); |
76 | | - System.out.println("Detected model type: " + modelType); |
77 | | - |
78 | | - // load model (vocabulary, tokenizer, configuration, tensors, weights) |
79 | | - return switch (modelType) { |
80 | | - case LLAMA_3 -> loadLlamaModel(fileChannel, gguf, contextLength, loadWeights); |
81 | | - case MISTRAL -> loadMistralModel(fileChannel, gguf, contextLength, loadWeights); |
82 | | - default -> throw new UnsupportedOperationException("Unsupported model type: " + modelType); |
83 | | - }; |
84 | | - } |
85 | | - |
86 | | - public static Llama loadLlamaModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException { |
87 | | - try (var ignored = Timer.log("Load LlaMa model")) { |
88 | | - Map<String, Object> metadata = gguf.getMetadata(); |
89 | | - |
90 | | - Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata); |
91 | | - Tokenizer tokenizer = createLlama3Tokenizer(metadata, vocabulary); |
92 | | - |
93 | | - LlamaConfiguration config = new LlamaConfiguration( |
94 | | - (int) metadata.get("llama.embedding_length"), |
95 | | - (int) metadata.get("llama.feed_forward_length"), |
96 | | - (int) metadata.get("llama.block_count"), |
97 | | - (int) metadata.get("llama.attention.head_count"), |
98 | | - |
99 | | - metadata.containsKey("llama.attention.head_count_kv") ? |
100 | | - (int) metadata.get("llama.attention.head_count_kv") : |
101 | | - (int) metadata.get("llama.attention.head_count"), |
102 | | - |
103 | | - vocabulary.size(), |
104 | | - (int) metadata.get("llama.context_length"), |
105 | | - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), |
106 | | - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) |
107 | | - ).withContextLength(contextLength); |
108 | | - |
109 | | - Weights weights = null; |
110 | | - if (loadWeights) { |
111 | | - Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); |
112 | | - weights = loadWeights(tensorEntries, config); |
113 | | - } |
114 | | - return new Llama(config, tokenizer, weights); |
115 | | - } |
116 | | - } |
117 | | - |
118 | | - public static Mistral loadMistralModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) { |
119 | | - try (var ignored = Timer.log("Load Mistral model")) { |
120 | | - Map<String, Object> metadata = gguf.getMetadata(); |
121 | | - |
122 | | - Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata); |
123 | | - Tokenizer tokenizer = createMistralTokenizer(metadata, vocabulary); |
124 | | - |
125 | | - int modelContextLength = (int) metadata.get("llama.context_length"); |
126 | | - if (contextLength < 0 || modelContextLength < contextLength) { |
127 | | - contextLength = modelContextLength; |
128 | | - } |
129 | | - |
130 | | - MistralConfiguration config = new MistralConfiguration( |
131 | | - (int) metadata.get("llama.embedding_length"), |
132 | | - (int) metadata.get("llama.feed_forward_length"), |
133 | | - (int) metadata.get("llama.block_count"), |
134 | | - (int) metadata.get("llama.attention.head_count"), |
135 | | - |
136 | | - metadata.containsKey("llama.attention.head_count_kv") |
137 | | - ? (int) metadata.get("llama.attention.head_count_kv") |
138 | | - : (int) metadata.get("llama.attention.head_count"), |
139 | | - |
140 | | - vocabulary.size(), |
141 | | - contextLength, |
142 | | - false, |
143 | | - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), |
144 | | - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) |
145 | | - ); |
146 | | - |
147 | | - Weights weights = null; |
148 | | - if (loadWeights) { |
149 | | - Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); |
150 | | - weights = loadWeights(tensorEntries, config); |
151 | | - } |
152 | | - return new Mistral(config, tokenizer, weights); |
153 | | - } catch (IOException e) { |
154 | | - throw new RuntimeException(e); |
155 | | - } |
| 74 | + // model type-specific load |
| 75 | + return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights); |
156 | 76 | } |
157 | 77 |
|
158 | 78 | public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) { |
|
0 commit comments