@@ -80,78 +80,66 @@ public class Weights {
8080 * RoPE sine components
8181 * @param wcls
8282 * Classifier weights for output logits
     *
     * Constructor for standard (non-TornadoVM) mode.
8386 */
84- public Weights (FloatTensor token_embedding_table , FloatBuffer [] rms_att_weight , FloatTensor [] wq , FloatTensor [] wk , FloatTensor [] wv , FloatTensor [] wo , FloatBuffer [] rms_ffn_weight ,
85- FloatTensor [] w1 , FloatTensor [] w2 , FloatTensor [] w3 , FloatBuffer rms_final_weight , FloatBuffer freq_cis_real , FloatBuffer freq_cis_imag , FloatTensor wcls , GGMLType weightType ) {
87+ public Weights (FloatTensor token_embedding_table , FloatBuffer [] rms_att_weight ,
88+ FloatTensor [] wq , FloatTensor [] wk , FloatTensor [] wv , FloatTensor [] wo ,
89+ FloatBuffer [] rms_ffn_weight , FloatTensor [] w1 , FloatTensor [] w2 , FloatTensor [] w3 ,
90+ FloatBuffer rms_final_weight , FloatBuffer freq_cis_real , FloatBuffer freq_cis_imag ,
91+ FloatTensor wcls , GGMLType weightType ) {
92+ // Standard format
8693 this .token_embedding_table = token_embedding_table ;
8794 this .rms_att_weight = rms_att_weight ;
88- this .wq = wq ;
89- this .wk = wk ;
90- this .wv = wv ;
91- this .wo = wo ;
95+ this .wq = wq ; this .wk = wk ; this .wv = wv ; this .wo = wo ;
9296 this .rms_ffn_weight = rms_ffn_weight ;
93- this .w1 = w1 ;
94- this .w2 = w2 ;
95- this .w3 = w3 ;
97+ this .w1 = w1 ; this .w2 = w2 ; this .w3 = w3 ;
98+ this .wcls = wcls ;
9699 this .rms_final_weight = rms_final_weight ;
97100 this .freq_cis_real = freq_cis_real ;
98101 this .freq_cis_imag = freq_cis_imag ;
99- this .wcls = wcls ;
100- this .tokenEmbeddingTable = loadToFloatArray (token_embedding_table ); // (vocab_size, dim)
101-
102- if (LlamaApp .USE_TORNADOVM ) {
103- this .freq_cis_imagFlat = loadToSingleFloatArray (freq_cis_imag );
104- this .freq_cis_realFlat = loadToSingleFloatArray (freq_cis_real );
105-
106- // Store read-only weight as a ByteArray in TornadoVM
107- this .wclsByteArray = ByteArray .fromSegment (wcls .asMemorySegment ());
108- this .rms_final_weight_as_floatArray = FloatArray .fromFloatBuffer (rms_final_weight );
109-
110- this .rms_att_weightLayered = loadToFloatArray (rms_att_weight );
111-
112- this .wqLayered = loadToFloatArray (wq );
113- this .wkLayered = loadToFloatArray (wk );
114- this .wvLayered = loadToFloatArray (wv );
115- this .woLayered = loadToFloatArray (wo );
116- this .rms_ffn_weightLayered = loadToFloatArray (rms_ffn_weight );
117- this .w1Layered = loadToFloatArray (w1 );
118- this .w2Layered = loadToFloatArray (w2 );
119- this .w3Layered = loadToFloatArray (w3 );
120-
121- } else {
122- this .freq_cis_imagFlat = null ;
123- this .freq_cis_realFlat = null ;
124- this .wclsByteArray = null ;
125- this .rms_final_weight_as_floatArray = null ;
126- this .rms_att_weightLayered = null ;
127- this .wqLayered = null ;
128- this .wkLayered = null ;
129- this .wvLayered = null ;
130- this .woLayered = null ;
131- this .rms_ffn_weightLayered = null ;
132- this .w1Layered = null ;
133- this .w2Layered = null ;
134- this .w3Layered = null ;
135- }
136102 this .weightType = weightType ;
103+
104+ // TornadoVM format (null when not using TornadoVM)
105+ this .tokenEmbeddingTable = null ;
106+ this .rms_att_weightLayered = null ;
107+ this .wqLayered = null ; this .wkLayered = null ; this .wvLayered = null ; this .woLayered = null ;
108+ this .rms_ffn_weightLayered = null ;
109+ this .w1Layered = null ; this .w2Layered = null ; this .w3Layered = null ;
110+ this .rms_final_weight_as_floatArray = null ;
111+ this .freq_cis_realFlat = null ; this .freq_cis_imagFlat = null ;
112+ this .wclsByteArray = null ;
137113 }
138114
139115 /**
140- * Converts an array of FloatTensor objects to TornadoVM FloatArray format. This enables efficient GPU computation by flattening multi-dimensional tensors.
141- *
142- * @param array
143- * Array of FloatTensors to convert
144- * @return Array of FloatArrays with the same data
116+ * Constructor for TornadoVM mode
145117 */
146- private static FloatArray [] loadToFloatArray (FloatTensor [] array ) {
147- FloatArray [] floatArrays = new FloatArray [array .length ];
148- for (int i = 0 ; i < array .length ; i ++) {
149- floatArrays [i ] = new FloatArray (array [i ].size ());
150- for (int j = 0 ; j < array [i ].size (); j ++) {
151- floatArrays [i ].set (j , array [i ].getFloat (j ));
152- }
153- }
154- return floatArrays ;
118+ public Weights (FloatArray tokenEmbeddingTable , FloatArray [] rms_att_weightLayered ,
119+ FloatArray [] wqLayered , FloatArray [] wkLayered , FloatArray [] wvLayered , FloatArray [] woLayered ,
120+ FloatArray [] rms_ffn_weightLayered , FloatArray [] w1Layered , FloatArray [] w2Layered , FloatArray [] w3Layered ,
121+ FloatArray rms_final_weight_as_floatArray , FloatArray freq_cis_realFlat , FloatArray freq_cis_imagFlat ,
122+ ByteArray wclsByteArray , GGMLType weightType ) {
123+ // Standard format (null when using TornadoVM)
124+ this .token_embedding_table = null ;
125+ this .rms_att_weight = null ;
126+ this .wq = null ; this .wk = null ; this .wv = null ; this .wo = null ;
127+ this .rms_ffn_weight = null ;
128+ this .w1 = null ; this .w2 = null ; this .w3 = null ;
129+ this .wcls = null ;
130+ this .rms_final_weight = null ;
131+ this .freq_cis_real = null ; this .freq_cis_imag = null ;
132+
133+ // TornadoVM format
134+ this .tokenEmbeddingTable = tokenEmbeddingTable ;
135+ this .rms_att_weightLayered = rms_att_weightLayered ;
136+ this .wqLayered = wqLayered ; this .wkLayered = wkLayered ; this .wvLayered = wvLayered ; this .woLayered = woLayered ;
137+ this .rms_ffn_weightLayered = rms_ffn_weightLayered ;
138+ this .w1Layered = w1Layered ; this .w2Layered = w2Layered ; this .w3Layered = w3Layered ;
139+ this .rms_final_weight_as_floatArray = rms_final_weight_as_floatArray ;
140+ this .freq_cis_realFlat = freq_cis_realFlat ; this .freq_cis_imagFlat = freq_cis_imagFlat ;
141+ this .wclsByteArray = wclsByteArray ;
142+ this .weightType = weightType ;
155143 }
156144
157145 /**
0 commit comments