11package com .example .loader .weights ;
22
3- import com .example .LlamaApp ;
43import com .example .core .model .GGMLType ;
54import com .example .core .model .tensor .FloatTensor ;
5+ import com .example .core .model .tensor .GGMLTensorEntry ;
6+ import com .example .core .types .Float16 ;
67import uk .ac .manchester .tornado .api .types .HalfFloat ;
78import uk .ac .manchester .tornado .api .types .arrays .ByteArray ;
89import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
910import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
1011
12+ import java .lang .foreign .MemorySegment ;
13+ import java .nio .ByteOrder ;
1114import java .nio .FloatBuffer ;
15+ import java .util .function .IntFunction ;
16+
17+ import static com .example .core .model .tensor .FloatTensor .readByte ;
18+ import static com .example .core .model .tensor .FloatTensor .readShort ;
1219
1320public class Weights {
1421 // token embedding table
@@ -28,7 +35,7 @@ public class Weights {
2835 public final FloatTensor [] w3 ; // (layer, hidden_dim, dim)
2936 //
3037 public final FloatTensor wcls ; // (vocab_size, dim)
31- public final ByteArray wclsByteArray ;
38+ public final HalfFloatArray wclsHalfFloat ;
3239 // public final rmsnorm
3340 public final FloatBuffer rms_final_weight ; // (dim,)
3441 // freq_cis for RoPE relatively positional embeddings
@@ -51,6 +58,7 @@ public class Weights {
5158 public FloatArray freq_cis_imagFlat ; // (seq_len, head_size/2)
5259 // (optional) classifier weights for the logits, on the last layer
5360 public GGMLType weightType ;
61+
5462 /**
5563 * Constructor to initialize all weight tensors for the model. Automatically creates TornadoVM-compatible versions when needed.
5664 *
@@ -86,17 +94,19 @@ public class Weights {
8694 /**
8795 * Constructor for standard (non-TornadoVM) mode
8896 */
89- public Weights (FloatTensor token_embedding_table , FloatBuffer [] rms_att_weight ,
90- FloatTensor [] wq , FloatTensor [] wk , FloatTensor [] wv , FloatTensor [] wo ,
91- FloatBuffer [] rms_ffn_weight , FloatTensor [] w1 , FloatTensor [] w2 , FloatTensor [] w3 ,
92- FloatBuffer rms_final_weight , FloatBuffer freq_cis_real , FloatBuffer freq_cis_imag ,
93- FloatTensor wcls , GGMLType weightType ) {
97+ public Weights (FloatTensor token_embedding_table , FloatBuffer [] rms_att_weight , FloatTensor [] wq , FloatTensor [] wk , FloatTensor [] wv , FloatTensor [] wo , FloatBuffer [] rms_ffn_weight ,
98+ FloatTensor [] w1 , FloatTensor [] w2 , FloatTensor [] w3 , FloatBuffer rms_final_weight , FloatBuffer freq_cis_real , FloatBuffer freq_cis_imag , FloatTensor wcls , GGMLType weightType ) {
9499 // Standard format
95100 this .token_embedding_table = token_embedding_table ;
96101 this .rms_att_weight = rms_att_weight ;
97- this .wq = wq ; this .wk = wk ; this .wv = wv ; this .wo = wo ;
102+ this .wq = wq ;
103+ this .wk = wk ;
104+ this .wv = wv ;
105+ this .wo = wo ;
98106 this .rms_ffn_weight = rms_ffn_weight ;
99- this .w1 = w1 ; this .w2 = w2 ; this .w3 = w3 ;
107+ this .w1 = w1 ;
108+ this .w2 = w2 ;
109+ this .w3 = w3 ;
100110 this .wcls = wcls ;
101111 this .rms_final_weight = rms_final_weight ;
102112 this .freq_cis_real = freq_cis_real ;
@@ -106,110 +116,58 @@ public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight,
106116 // TornadoVM format (null when not using TornadoVM)
107117 this .tokenEmbeddingTable = null ;
108118 this .rms_att_weightLayered = null ;
109- this .wqLayered = null ; this .wkLayered = null ; this .wvLayered = null ; this .woLayered = null ;
119+ this .wqLayered = null ;
120+ this .wkLayered = null ;
121+ this .wvLayered = null ;
122+ this .woLayered = null ;
110123 this .rms_ffn_weightLayered = null ;
111- this .w1Layered = null ; this .w2Layered = null ; this .w3Layered = null ;
124+ this .w1Layered = null ;
125+ this .w2Layered = null ;
126+ this .w3Layered = null ;
112127 this .rms_final_weight_as_floatArray = null ;
113- this .freq_cis_realFlat = null ; this .freq_cis_imagFlat = null ;
114- this .wclsByteArray = null ;
128+ this .freq_cis_realFlat = null ;
129+ this .freq_cis_imagFlat = null ;
130+ this .wclsHalfFloat = null ;
115131 }
116132
117133 /**
118134 * Constructor for TornadoVM mode
119135 */
120- public Weights (FloatArray tokenEmbeddingTable ,
121- FloatArray [] rms_att_weightLayered ,
122- HalfFloatArray [] wqLayered , HalfFloatArray [] wkLayered , HalfFloatArray [] wvLayered , HalfFloatArray [] woLayered ,
123- FloatArray [] rms_ffn_weightLayered , HalfFloatArray [] w1Layered , HalfFloatArray [] w2Layered , HalfFloatArray [] w3Layered ,
124- FloatArray rms_final_weight_as_floatArray , FloatArray freq_cis_realFlat , FloatArray freq_cis_imagFlat ,
125- ByteArray wclsByteArray , GGMLType weightType ) {
136+ public Weights (FloatArray tokenEmbeddingTable , FloatArray [] rms_att_weightLayered , HalfFloatArray [] wqLayered , HalfFloatArray [] wkLayered , HalfFloatArray [] wvLayered , HalfFloatArray [] woLayered ,
137+ FloatArray [] rms_ffn_weightLayered , HalfFloatArray [] w1Layered , HalfFloatArray [] w2Layered , HalfFloatArray [] w3Layered , FloatArray rms_final_weight_as_floatArray ,
138+ FloatArray freq_cis_realFlat , FloatArray freq_cis_imagFlat , HalfFloatArray wclsByteArray , GGMLType weightType ) {
126139 // Standard format (null when using TornadoVM)
127140 this .token_embedding_table = null ;
128141 this .rms_att_weight = null ;
129- this .wq = null ; this .wk = null ; this .wv = null ; this .wo = null ;
142+ this .wq = null ;
143+ this .wk = null ;
144+ this .wv = null ;
145+ this .wo = null ;
130146 this .rms_ffn_weight = null ;
131- this .w1 = null ; this .w2 = null ; this .w3 = null ;
147+ this .w1 = null ;
148+ this .w2 = null ;
149+ this .w3 = null ;
132150 this .wcls = null ;
133151 this .rms_final_weight = null ;
134- this .freq_cis_real = null ; this .freq_cis_imag = null ;
152+ this .freq_cis_real = null ;
153+ this .freq_cis_imag = null ;
135154
136155 // TornadoVM format
137156 this .tokenEmbeddingTable = tokenEmbeddingTable ;
138157 this .rms_att_weightLayered = rms_att_weightLayered ;
139- this .wqLayered = wqLayered ; this .wkLayered = wkLayered ; this .wvLayered = wvLayered ; this .woLayered = woLayered ;
158+ this .wqLayered = wqLayered ;
159+ this .wkLayered = wkLayered ;
160+ this .wvLayered = wvLayered ;
161+ this .woLayered = woLayered ;
140162 this .rms_ffn_weightLayered = rms_ffn_weightLayered ;
141- this .w1Layered = w1Layered ; this .w2Layered = w2Layered ; this .w3Layered = w3Layered ;
163+ this .w1Layered = w1Layered ;
164+ this .w2Layered = w2Layered ;
165+ this .w3Layered = w3Layered ;
142166 this .rms_final_weight_as_floatArray = rms_final_weight_as_floatArray ;
143- this .freq_cis_realFlat = freq_cis_realFlat ; this .freq_cis_imagFlat = freq_cis_imagFlat ;
144- this .wclsByteArray = wclsByteArray ;
167+ this .freq_cis_realFlat = freq_cis_realFlat ;
168+ this .freq_cis_imagFlat = freq_cis_imagFlat ;
169+ this .wclsHalfFloat = wclsByteArray ;
145170 this .weightType = weightType ;
146171 }
147172
148- /**
149- * Converts an array of FloatBuffer objects to TornadoVM FloatArray format. Preserves the original buffer position after conversion.
150- *
151- * @param array
152- * Array of FloatBuffers to convert
153- * @return Array of FloatArrays with the same data
154- */
155- private static FloatArray [] loadToFloatArray (FloatBuffer [] array ) {
156- FloatArray [] result = new FloatArray [array .length ];
157- for (int i = 0 ; i < array .length ; i ++) {
158- int size = array [i ].remaining ();
159- result [i ] = new FloatArray (size );
160-
161- // Save and restore buffer position to avoid side effects
162- int originalPosition = array [i ].position ();
163-
164- for (int j = 0 ; j < size ; j ++) {
165- float value = array [i ].get ();
166- result [i ].set (j , value );
167- }
168- // Reset buffer position
169- array [i ].position (originalPosition );
170- }
171-
172- return result ;
173- }
174-
175-
176- /**
177- * Converts a single FloatBuffer to a TornadoVM FloatArray. Creates a duplicate buffer to avoid modifying the original.
178- *
179- * @param input
180- * FloatBuffer to convert
181- * @return FloatArray with the same data
182- */
183- private static FloatArray loadToSingleFloatArray (FloatBuffer input ) {
184- // Create a duplicate to prevent modifying the original buffer
185- FloatBuffer copy = input .duplicate ();
186- int totalSize = copy .remaining ();
187-
188- FloatArray result = new FloatArray (totalSize );
189-
190- int index = 0 ;
191- while (copy .hasRemaining ()) {
192- result .set (index ++, copy .get ());
193- }
194-
195- return result ;
196- }
197-
198- /**
199- * Converts a FloatTensor to a TornadoVM FloatArray.
200- *
201- * @param input
202- * FloatTensor to convert
203- * @return FloatArray with the same data
204- */
205- public FloatArray loadToFloatArray (FloatTensor input ) {
206- FloatArray floatArray = new FloatArray (input .size ());
207-
208- for (int i = 0 ; i < input .size (); i ++) {
209- floatArray .set (i , input .getFloat (i ));
210- }
211-
212- return floatArray ;
213- }
214-
215173}
0 commit comments