99import com .example .core .model .tensor .GGMLTensorEntry ;
1010import com .example .core .model .tensor .Q4_0FloatTensor ;
1111import com .example .core .model .tensor .Q8_0FloatTensor ;
12- import com .example .core .types .Float16 ;
1312import com .example .core .types .Pair ;
1413import com .example .inference .engine .impl .Configuration ;
1514import com .example .inference .engine .impl .Llama ;
2221import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
2322
2423import java .io .IOException ;
25- import java .lang .foreign .MemorySegment ;
2624import java .nio .ByteOrder ;
2725import java .nio .FloatBuffer ;
2826import java .nio .channels .FileChannel ;
3533import java .util .stream .Collectors ;
3634import java .util .stream .IntStream ;
3735
38- import static com .example .core .model .tensor .FloatTensor .readByte ;
39- import static com .example .core .model .tensor .FloatTensor .readShort ;
40-
4136public final class ModelLoader {
4237 private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2" ;
4338
@@ -105,8 +100,7 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
105100 GGMLTensorEntry outputWeight ) {
106101 return new Weights (
107102 // Load directly to TornadoVM format
108- loadTensorAsFloatArray (tokenEmbeddings ),
109- loadArrayAsFloatArrayFromBuffer (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_norm.weight" )),
103+ loadTensorAsFloatArray (tokenEmbeddings ), loadArrayAsFloatArrayFromBuffer (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_norm.weight" )),
110104 loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_q.weight" )),
111105 loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_k.weight" )),
112106 loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".attn_v.weight" )),
@@ -115,7 +109,7 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
115109 loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_gate.weight" )),
116110 loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_down.weight" )),
117111 loadArrayAsHalfFloatArray (config .numberOfLayers , i -> tensorEntries .get ("blk." + i + ".ffn_up.weight" )), floatBufferToFloatArray (tensorEntries .get ("output_norm.weight" )),
118- FloatArray .fromArray (ropeFreqs .first ()), FloatArray .fromArray (ropeFreqs .second ()), createByteArrayFromTensor (outputWeight ), outputWeight .ggmlType ());
112+ FloatArray .fromArray (ropeFreqs .first ()), FloatArray .fromArray (ropeFreqs .second ()), loadTensorAsHalfFloatArray (outputWeight ), outputWeight .ggmlType ());
119113 }
120114
121115 /**
@@ -135,23 +129,51 @@ private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensor
135129 FloatBuffer .wrap (ropeFreqs .first ()), FloatBuffer .wrap (ropeFreqs .second ()), loadQuantized (outputWeight ), outputWeight .ggmlType ());
136130 }
137131
138- private static FloatArray [] loadArrayAsFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
132+ private static Tokenizer createTokenizer (Map <String , Object > metadata , Vocabulary vocabulary ) {
133+ String [] mergeLines = (String []) metadata .get ("tokenizer.ggml.merges" );
134+ List <Pair <Integer , Integer >> merges = Arrays .stream (mergeLines ).map (line -> line .split (" " ))
135+ .map (parts -> new Pair <>(vocabulary .getIndex (parts [0 ]).orElseThrow (), vocabulary .getIndex (parts [1 ]).orElseThrow ())).toList ();
136+
137+ int allTokens = vocabulary .size ();
138+ int baseTokens = 128000 ; // assume all tokens after the base ones are special.
139+ int reservedSpecialTokens = allTokens - baseTokens ;
140+ List <String > specialTokensList = Arrays .stream (vocabulary .tokens (), baseTokens , allTokens ).toList ();
141+
142+ assert specialTokensList .stream ().allMatch (token -> vocabulary .getIndex (token ).isPresent ());
143+
144+ Map <String , Integer > specialTokens = IntStream .range (0 , specialTokensList .size ()).boxed ().collect (Collectors .toMap (i -> specialTokensList .get (i ), i -> baseTokens + i ));
145+
146+ return new Tokenizer (vocabulary , merges , LLAMA_3_PATTERN , specialTokens );
147+ }
148+
149+ public static FloatTensor loadQuantized (GGMLTensorEntry entry ) {
150+ GGMLType ggmlType = entry .ggmlType ();
151+ return switch (ggmlType ) {
152+ // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
153+ case Q8_0 -> new Q8_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
154+ case Q4_0 -> new Q4_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
155+ case F16 -> new F16FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
156+ default -> throw new UnsupportedOperationException ("Quantization format " + ggmlType );
157+ };
158+ }
159+
160+ public static FloatArray [] loadArrayAsFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
139161 FloatArray [] array = new FloatArray [size ];
140162 for (int i = 0 ; i < size ; i ++) {
141163 array [i ] = loadTensorAsFloatArray (getTensorEntry .apply (i ));
142164 }
143165 return array ;
144166 }
145167
146- private static HalfFloatArray [] loadArrayAsHalfFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
168+ public static HalfFloatArray [] loadArrayAsHalfFloatArray (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
147169 HalfFloatArray [] array = new HalfFloatArray [size ];
148170 for (int i = 0 ; i < size ; i ++) {
149171 array [i ] = loadTensorAsHalfFloatArray (getTensorEntry .apply (i ));
150172 }
151173 return array ;
152174 }
153175
154- private static FloatArray floatBufferToFloatArray (GGMLTensorEntry tensorEntry ) {
176+ public static FloatArray floatBufferToFloatArray (GGMLTensorEntry tensorEntry ) {
155177 if (tensorEntry .ggmlType () == GGMLType .F32 ) {
156178 FloatBuffer buffer = tensorEntry .memorySegment ().asByteBuffer ().order (ByteOrder .LITTLE_ENDIAN ).asFloatBuffer ();
157179 return FloatArray .fromFloatBuffer (buffer );
@@ -160,21 +182,20 @@ private static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
160182 }
161183 }
162184
163-
164- private static FloatArray [] loadArrayAsFloatArrayFromBuffer (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
185+ public static FloatArray [] loadArrayAsFloatArrayFromBuffer (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
165186 FloatArray [] array = new FloatArray [size ];
166187 for (int i = 0 ; i < size ; i ++) {
167188 array [i ] = floatBufferToFloatArray (getTensorEntry .apply (i ));
168189 }
169190 return array ;
170191 }
171192
172- private static ByteArray createByteArrayFromTensor (GGMLTensorEntry entry ) {
193+ public static ByteArray createByteArrayFromTensor (GGMLTensorEntry entry ) {
173194 FloatTensor tensor = loadQuantized (entry );
174195 return ByteArray .fromSegment (tensor .asMemorySegment ());
175196 }
176197
177- private static FloatArray loadTensorAsFloatArray (GGMLTensorEntry entry ) {
198+ public static FloatArray loadTensorAsFloatArray (GGMLTensorEntry entry ) {
178199 if (entry .ggmlType () == GGMLType .F32 ) {
179200 // For F32, we can directly create FloatArray from memory
180201 FloatBuffer buffer = entry .memorySegment ().asByteBuffer ().order (ByteOrder .LITTLE_ENDIAN ).asFloatBuffer ();
@@ -194,18 +215,10 @@ private static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
194215 }
195216 }
196217
197- private static HalfFloatArray loadTensorAsHalfFloatArray (GGMLTensorEntry entry ) {
218+ public static HalfFloatArray loadTensorAsHalfFloatArray (GGMLTensorEntry entry ) {
198219 if (entry .ggmlType () == GGMLType .F32 ) {
199- // For F32, we can directly create FloatArray from memory
200- // FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
201- // FloatArray array = new FloatArray(buffer.remaining());
202- // for (int i = 0; i < buffer.remaining(); i++) {
203- // array.set(i, buffer.get());
204- // }
205- // return array
206- // ;
207220 System .out .println ("Loading F32 tensor as HalfFloatArray" );
208- return null ;
221+ return null ;
209222 } else {
210223 // For quantized formats, we need to load through FloatTensor
211224 FloatTensor tensor = loadQuantized (entry );
@@ -218,52 +231,6 @@ private static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry)
218231 }
219232 }
220233
221- public static float getFloat (int index , int size , MemorySegment memorySegment ) {
222- assert 0 <= index && index < size ;
223- int blockIndex = index / GGMLType .Q4_0 .getBlockSize ();
224- int blockOffset = blockIndex * GGMLType .Q4_0 .getTypeSize ();
225- float scale = Float .float16ToFloat (readShort (memorySegment , blockOffset ));
226- byte quant ;
227- int modIndex = index % GGMLType .Q4_0 .getBlockSize ();
228- if (modIndex < GGMLType .Q4_0 .getBlockSize () / 2 ) {
229- quant = (byte ) (readByte (memorySegment , blockOffset + Float16 .BYTES + modIndex ) & 0x0F );
230- } else {
231- quant = (byte ) ((readByte (memorySegment , blockOffset + Float16 .BYTES + modIndex - GGMLType .Q4_0 .getBlockSize () / 2 ) >>> 4 ) & 0x0F );
232- }
233- quant -= 8 ;
234- return quant * scale ;
235- }
236-
237- private static Tokenizer createTokenizer (Map <String , Object > metadata , Vocabulary vocabulary ) {
238- String [] mergeLines = (String []) metadata .get ("tokenizer.ggml.merges" );
239- List <Pair <Integer , Integer >> merges = Arrays .stream (mergeLines ).map (line -> line .split (" " ))
240- .map (parts -> new Pair <>(vocabulary .getIndex (parts [0 ]).orElseThrow (), vocabulary .getIndex (parts [1 ]).orElseThrow ())).toList ();
241-
242- int allTokens = vocabulary .size ();
243- int baseTokens = 128000 ; // assume all tokens after the base ones are special.
244- int reservedSpecialTokens = allTokens - baseTokens ;
245- List <String > specialTokensList = Arrays .stream (vocabulary .tokens (), baseTokens , allTokens ).toList ();
246-
247- assert specialTokensList .stream ().allMatch (token -> vocabulary .getIndex (token ).isPresent ());
248-
249- Map <String , Integer > specialTokens = IntStream .range (0 , specialTokensList .size ()).boxed ().collect (Collectors .toMap (i -> specialTokensList .get (i ), i -> baseTokens + i ));
250-
251- return new Tokenizer (vocabulary , merges , LLAMA_3_PATTERN , specialTokens );
252- }
253-
254- public static FloatTensor loadQuantized (GGMLTensorEntry entry ) {
255- GGMLType ggmlType = entry .ggmlType ();
256- // System.out.println("Loading quantized tensor of type " + entry.name());
257- return switch (ggmlType ) {
258- // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
259- case Q8_0 -> new Q8_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
260- case Q4_0 -> new Q4_0FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
261- // case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
262- case F16 -> new F16FloatTensor (FloatTensor .numberOfElements (entry .shape ()), entry .memorySegment ());
263- default -> throw new UnsupportedOperationException ("Quantization format " + ggmlType );
264- };
265- }
266-
267234 public static FloatTensor [] loadArrayOfQuantized (int size , IntFunction <GGMLTensorEntry > getTensorEntry ) {
268235 FloatTensor [] array = new FloatTensor [size ];
269236 for (int i = 0 ; i < size ; i ++) {
0 commit comments