|
5 | 5 | import com.example.inference.sampler.CategoricalSampler; |
6 | 6 | import com.example.inference.sampler.Sampler; |
7 | 7 | import com.example.inference.sampler.ToppSampler; |
8 | | -import com.example.model.Model; |
9 | 8 | import com.example.loader.weights.ModelLoader; |
| 9 | +import com.example.model.Model; |
10 | 10 | import com.example.tornadovm.FloatArrayUtils; |
11 | 11 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; |
12 | 12 |
|
@@ -106,16 +106,48 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp, |
106 | 106 | return sampler; |
107 | 107 | } |
108 | 108 |
|
109 | | - public static void main(String[] args) throws IOException { |
110 | | - Options options = Options.parseOptions(args); |
111 | | - Model model; |
| 109 | + /** |
| 110 | + * Loads the language model based on the given options. |
| 111 | + * <p> |
| 112 | + * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. |
| 113 | + * Otherwise, loads the model from the specified path using the model loader. |
| 114 | + * </p> |
| 115 | + * |
| 116 | + * @param options the parsed CLI options containing model path and max token limit |
| 117 | + * @return the loaded {@link Model} instance |
| 118 | + * @throws IOException if the model fails to load |
| 119 | + * @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable |
| 120 | + */ |
| 121 | + private static Model loadModel(Options options) throws IOException { |
112 | 122 | if (USE_AOT) { |
113 | | - model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); |
114 | | - } else { |
115 | | - model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true); |
| 123 | + Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); |
| 124 | + if (model == null) { |
| 125 | + throw new IllegalStateException("Failed to load precompiled AOT model."); |
| 126 | + } |
| 127 | + return model; |
116 | 128 | } |
117 | | - assert model != null; |
118 | | - Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); |
| 129 | + return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true); |
| 130 | + } |
| 131 | + |
| 132 | + private static Sampler createSampler(Model model, Options options) { |
| 133 | + return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); |
| 134 | + } |
| 135 | + |
| 136 | + /** |
| 137 | + * Entry point for running the LLaMA-based model with provided command-line arguments. |
| 138 | + * |
| 139 | + * <p>Initializes model options, loads the appropriate model (either AOT or on-demand), |
| 140 | + * configures the sampler, and runs either in interactive or single-instruction mode |
| 141 | + * based on the input options.</p> |
| 142 | + * |
| 143 | + * @param args command-line arguments used to configure model path, temperature, seed, etc. |
| 144 | + * @throws IOException if model loading or file operations fail. |
| 145 | + */ |
| 146 | + public static void main(String[] args) throws IOException { |
| 147 | + Options options = Options.parseOptions(args); |
| 148 | + Model model = loadModel(options); |
| 149 | + Sampler sampler = createSampler(model, options); |
| 150 | + |
119 | 151 | if (options.interactive()) { |
120 | 152 | model.runInteractive(sampler, options); |
121 | 153 | } else { |
|
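For reference, a minimal, self-contained sketch of the fail-fast loading shape that the new `loadModel(Options)` helper introduces (a null AOT result now raises `IllegalStateException` instead of relying on the old `assert model != null` in `main`). The stub types, the `model.gguf` path, and the method names other than the exception message are illustrative placeholders, not the project's real `AOT`, `ModelLoader`, or `Options` classes.

```java
import java.io.IOException;
import java.nio.file.Path;

public final class LoaderSketch {
    // Placeholder standing in for com.example.model.Model.
    record Model(String name) {}

    // Toggle mirroring the USE_AOT flag used in the diff.
    static final boolean USE_AOT = false;

    // Stand-in for AOT.tryUsePreLoaded: returns null when no precompiled model is available.
    static Model tryUsePreLoaded(Path modelPath, int maxTokens) {
        return null;
    }

    // Stand-in for ModelLoader.loadModel.
    static Model loadFromDisk(Path modelPath, int maxTokens) throws IOException {
        return new Model(modelPath.getFileName().toString());
    }

    // Same fail-fast shape as the refactored loadModel(Options): a missing AOT model
    // becomes an IllegalStateException rather than an assertion in main.
    static Model loadModel(Path modelPath, int maxTokens) throws IOException {
        if (USE_AOT) {
            Model preloaded = tryUsePreLoaded(modelPath, maxTokens);
            if (preloaded == null) {
                throw new IllegalStateException("Failed to load precompiled AOT model.");
            }
            return preloaded;
        }
        return loadFromDisk(modelPath, maxTokens);
    }

    public static void main(String[] args) throws IOException {
        Model model = loadModel(Path.of("model.gguf"), 512);
        System.out.println("Loaded: " + model.name());
    }
}
```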