@@ -1,25 +1,16 @@
 package com.example;

 import com.example.aot.AOT;
-import com.example.auxiliary.ChatFormat;
 import com.example.core.model.tensor.FloatTensor;
-import com.example.inference.CategoricalSampler;
-import com.example.inference.Sampler;
-import com.example.inference.ToppSampler;
-import com.example.inference.engine.impl.Llama;
-import com.example.inference.engine.impl.Options;
+import com.example.inference.sampler.CategoricalSampler;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.sampler.ToppSampler;
+import com.example.model.Model;
 import com.example.loader.weights.ModelLoader;
-import com.example.loader.weights.State;
 import com.example.tornadovm.FloatArrayUtils;
-import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Scanner;
-import java.util.Set;
-import java.util.function.IntConsumer;
 import java.util.random.RandomGenerator;
 import java.util.random.RandomGeneratorFactory;

@@ -115,156 +106,20 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
         return sampler;
     }

-    static void runInteractive(Llama model, Sampler sampler, Options options) {
-        State state = null;
-        List<Integer> conversationTokens = new ArrayList<>();
-        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
-        conversationTokens.add(chatFormat.beginOfText);
-        if (options.systemPrompt() != null) {
-            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        int startPosition = 0;
-        Scanner in = new Scanner(System.in);
-
-        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        try {
-            while (true) {
-                System.out.print("> ");
-                System.out.flush();
-                String userText = in.nextLine();
-                if (List.of("quit", "exit").contains(userText)) {
-                    break;
-                }
-                if (state == null) {
-                    state = model.createNewState();
-                }
-
-                if (USE_TORNADOVM && tornadoVMPlan == null) {
-                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-                }
-
-                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
-                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-                Set<Integer> stopTokens = chatFormat.getStopTokens();
-
-                List<Integer> responseTokens;
-                IntConsumer tokenConsumer = token -> {
-                    if (options.stream()) {
-                        if (!model.tokenizer().isSpecialToken(token)) {
-                            System.out.print(model.tokenizer().decode(List.of(token)));
-                        }
-                    }
-                };
-
-                // Choose between GPU and CPU path based on configuration
-                if (USE_TORNADOVM) {
-                    // GPU path using TornadoVM
-                    responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
-                            sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-                } else {
-                    // CPU path
-                    responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
-                            options.echo(), tokenConsumer);
-                }
-
-                // Include stop token in the prompt history, but not in the response displayed to the user.
-                conversationTokens.addAll(responseTokens);
-                startPosition = conversationTokens.size();
-                Integer stopToken = null;
-                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                    stopToken = responseTokens.getLast();
-                    responseTokens.removeLast();
-                }
-                if (!options.stream()) {
-                    String responseText = model.tokenizer().decode(responseTokens);
-                    System.out.println(responseText);
-                }
-                if (stopToken == null) {
-                    System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX");
-                    break;
-                }
-                System.out.print("\n");
-
-                // Optionally print performance metrics after each response
-                if (SHOW_PERF_INTERACTIVE) {
-                    Llama.LastRunMetrics.printMetrics();
-                }
-            }
-        } finally {
-            // Clean up TornadoVM resources when exiting the chat loop
-            if (USE_TORNADOVM && tornadoVMPlan != null) {
-                try {
-                    tornadoVMPlan.freeTornadoExecutionPlan();
-                } catch (Exception e) {
-                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
-                }
-            }
-        }
-    }
-
-    static void runInstructOnce(Llama model, Sampler sampler, Options options) {
-        State state = model.createNewState();
-        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.beginOfText);
-        if (options.systemPrompt() != null) {
-            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
-        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-        List<Integer> responseTokens;
-
-        // Define the token consumer
-        IntConsumer tokenConsumer = token -> {
-            if (options.stream()) {
-                if (!model.tokenizer().isSpecialToken(token)) {
-                    System.out.print(model.tokenizer().decode(List.of(token)));
-                }
-            }
-        };
-
-        Set<Integer> stopTokens = chatFormat.getStopTokens();
-        if (USE_TORNADOVM) {
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-            // Call generateTokensGPU without the token consumer parameter
-            responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-        } else {
-            // CPU path still uses the token consumer
-            responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
-        }
-
-        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-            responseTokens.removeLast();
-        }
-        if (!options.stream()) {
-            String responseText = model.tokenizer().decode(responseTokens);
-            System.out.println(responseText);
-        }
-
-        Llama.LastRunMetrics.printMetrics();
-
-        if (tornadoVMPlan != null) {
-            tornadoVMPlan.freeTornadoExecutionPlan();
-        }
-    }
-
     public static void main(String[] args) throws IOException {
         Options options = Options.parseOptions(args);
-        Llama model;
+        Model model;
         if (USE_AOT) {
             model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
         } else {
             model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
         }
-        Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
+        assert model != null;
+        Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
         if (options.interactive()) {
-            runInteractive(model, sampler, options);
+            model.runInteractive(sampler, options);
         } else {
-            runInstructOnce(model, sampler, options);
+            model.runInstructOnce(sampler, options);
         }
     }
 }
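Taken together, the two hunks show the shape of the refactor: the sampler classes move into com.example.inference.sampler, and the chat drivers (runInteractive, runInstructOnce) migrate off this launcher class onto the model itself, so main() only loads a Model and dispatches to it. Below is a minimal sketch of the surface com.example.model.Model would need for the new main() to compile. It is inferred from the calls in the added lines and from the members the removed static drivers relied on; the actual interface in the repository may declare different or additional methods, and the collaborator types (Configuration, Tokenizer, State, Options, Sampler) are referenced here only by the names they carry elsewhere in this diff.

// Sketch only, not the repository's actual definition: one plausible shape
// for com.example.model.Model, inferred from model.configuration().vocabularySize(),
// model.runInteractive(sampler, options), and model.runInstructOnce(sampler, options)
// in the new main(), plus what the removed static drivers used.
public interface Model {
    Configuration configuration();   // vocabularySize is now read via an accessor, per the diff
    Tokenizer tokenizer();           // the removed drivers decoded tokens through this
    State createNewState();          // per-conversation inference state, as in the old code
    void runInteractive(Sampler sampler, Options options);   // absorbs the removed runInteractive
    void runInstructOnce(Sampler sampler, Options options);  // absorbs the removed runInstructOnce
}

Under that reading, the TornadoVM plan lifecycle (initializeTornadoVMPlan / freeTornadoExecutionPlan) presumably moves behind the two run methods as well, since main() no longer touches GPU resources directly.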