Skip to content

Commit 783883c

Browse files
committed
Refactor Weights class to improve tensor handling
1 parent 4bed4e2 commit 783883c

File tree

1 file changed

+51
-93
lines changed

1 file changed

+51
-93
lines changed
Lines changed: 51 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
package com.example.loader.weights;
22

3-
import com.example.LlamaApp;
43
import com.example.core.model.GGMLType;
54
import com.example.core.model.tensor.FloatTensor;
5+
import com.example.core.model.tensor.GGMLTensorEntry;
6+
import com.example.core.types.Float16;
67
import uk.ac.manchester.tornado.api.types.HalfFloat;
78
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
89
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
910
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
1011

12+
import java.lang.foreign.MemorySegment;
13+
import java.nio.ByteOrder;
1114
import java.nio.FloatBuffer;
15+
import java.util.function.IntFunction;
16+
17+
import static com.example.core.model.tensor.FloatTensor.readByte;
18+
import static com.example.core.model.tensor.FloatTensor.readShort;
1219

1320
public class Weights {
1421
// token embedding table
@@ -28,7 +35,7 @@ public class Weights {
2835
public final FloatTensor[] w3; // (layer, hidden_dim, dim)
2936
//
3037
public final FloatTensor wcls; // (vocab_size, dim)
31-
public final ByteArray wclsByteArray;
38+
public final HalfFloatArray wclsHalfFloat;
3239
// final rmsnorm
3340
public final FloatBuffer rms_final_weight; // (dim,)
3441
// freq_cis for RoPE relative positional embeddings
@@ -51,6 +58,7 @@ public class Weights {
5158
public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2)
5259
// (optional) classifier weights for the logits, on the last layer
5360
public GGMLType weightType;
61+
5462
/**
5563
* Constructor to initialize all weight tensors for the model. Automatically creates TornadoVM-compatible versions when needed.
5664
*
@@ -86,17 +94,19 @@ public class Weights {
8694
/**
8795
* Constructor for standard (non-TornadoVM) mode
8896
*/
89-
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight,
90-
FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
91-
FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
92-
FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag,
93-
FloatTensor wcls, GGMLType weightType) {
97+
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight,
98+
FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
9499
// Standard format
95100
this.token_embedding_table = token_embedding_table;
96101
this.rms_att_weight = rms_att_weight;
97-
this.wq = wq; this.wk = wk; this.wv = wv; this.wo = wo;
102+
this.wq = wq;
103+
this.wk = wk;
104+
this.wv = wv;
105+
this.wo = wo;
98106
this.rms_ffn_weight = rms_ffn_weight;
99-
this.w1 = w1; this.w2 = w2; this.w3 = w3;
107+
this.w1 = w1;
108+
this.w2 = w2;
109+
this.w3 = w3;
100110
this.wcls = wcls;
101111
this.rms_final_weight = rms_final_weight;
102112
this.freq_cis_real = freq_cis_real;
@@ -106,110 +116,58 @@ public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight,
106116
// TornadoVM format (null when not using TornadoVM)
107117
this.tokenEmbeddingTable = null;
108118
this.rms_att_weightLayered = null;
109-
this.wqLayered = null; this.wkLayered = null; this.wvLayered = null; this.woLayered = null;
119+
this.wqLayered = null;
120+
this.wkLayered = null;
121+
this.wvLayered = null;
122+
this.woLayered = null;
110123
this.rms_ffn_weightLayered = null;
111-
this.w1Layered = null; this.w2Layered = null; this.w3Layered = null;
124+
this.w1Layered = null;
125+
this.w2Layered = null;
126+
this.w3Layered = null;
112127
this.rms_final_weight_as_floatArray = null;
113-
this.freq_cis_realFlat = null; this.freq_cis_imagFlat = null;
114-
this.wclsByteArray = null;
128+
this.freq_cis_realFlat = null;
129+
this.freq_cis_imagFlat = null;
130+
this.wclsHalfFloat = null;
115131
}
116132

117133
/**
118134
* Constructor for TornadoVM mode
119135
*/
120-
public Weights(FloatArray tokenEmbeddingTable,
121-
FloatArray[] rms_att_weightLayered,
122-
HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered,
123-
FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered,
124-
FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat,
125-
ByteArray wclsByteArray, GGMLType weightType) {
136+
public Weights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered,
137+
FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray,
138+
FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, GGMLType weightType) {
126139
// Standard format (null when using TornadoVM)
127140
this.token_embedding_table = null;
128141
this.rms_att_weight = null;
129-
this.wq = null; this.wk = null; this.wv = null; this.wo = null;
142+
this.wq = null;
143+
this.wk = null;
144+
this.wv = null;
145+
this.wo = null;
130146
this.rms_ffn_weight = null;
131-
this.w1 = null; this.w2 = null; this.w3 = null;
147+
this.w1 = null;
148+
this.w2 = null;
149+
this.w3 = null;
132150
this.wcls = null;
133151
this.rms_final_weight = null;
134-
this.freq_cis_real = null; this.freq_cis_imag = null;
152+
this.freq_cis_real = null;
153+
this.freq_cis_imag = null;
135154

136155
// TornadoVM format
137156
this.tokenEmbeddingTable = tokenEmbeddingTable;
138157
this.rms_att_weightLayered = rms_att_weightLayered;
139-
this.wqLayered = wqLayered; this.wkLayered = wkLayered; this.wvLayered = wvLayered; this.woLayered = woLayered;
158+
this.wqLayered = wqLayered;
159+
this.wkLayered = wkLayered;
160+
this.wvLayered = wvLayered;
161+
this.woLayered = woLayered;
140162
this.rms_ffn_weightLayered = rms_ffn_weightLayered;
141-
this.w1Layered = w1Layered; this.w2Layered = w2Layered; this.w3Layered = w3Layered;
163+
this.w1Layered = w1Layered;
164+
this.w2Layered = w2Layered;
165+
this.w3Layered = w3Layered;
142166
this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray;
143-
this.freq_cis_realFlat = freq_cis_realFlat; this.freq_cis_imagFlat = freq_cis_imagFlat;
144-
this.wclsByteArray = wclsByteArray;
167+
this.freq_cis_realFlat = freq_cis_realFlat;
168+
this.freq_cis_imagFlat = freq_cis_imagFlat;
169+
this.wclsHalfFloat = wclsByteArray;
145170
this.weightType = weightType;
146171
}
147172

148-
/**
149-
* Converts an array of FloatBuffer objects to TornadoVM FloatArray format. Preserves the original buffer position after conversion.
150-
*
151-
* @param array
152-
* Array of FloatBuffers to convert
153-
* @return Array of FloatArrays with the same data
154-
*/
155-
private static FloatArray[] loadToFloatArray(FloatBuffer[] array) {
156-
FloatArray[] result = new FloatArray[array.length];
157-
for (int i = 0; i < array.length; i++) {
158-
int size = array[i].remaining();
159-
result[i] = new FloatArray(size);
160-
161-
// Save and restore buffer position to avoid side effects
162-
int originalPosition = array[i].position();
163-
164-
for (int j = 0; j < size; j++) {
165-
float value = array[i].get();
166-
result[i].set(j, value);
167-
}
168-
// Reset buffer position
169-
array[i].position(originalPosition);
170-
}
171-
172-
return result;
173-
}
174-
175-
176-
/**
177-
* Converts a single FloatBuffer to a TornadoVM FloatArray. Creates a duplicate buffer to avoid modifying the original.
178-
*
179-
* @param input
180-
* FloatBuffer to convert
181-
* @return FloatArray with the same data
182-
*/
183-
private static FloatArray loadToSingleFloatArray(FloatBuffer input) {
184-
// Create a duplicate to prevent modifying the original buffer
185-
FloatBuffer copy = input.duplicate();
186-
int totalSize = copy.remaining();
187-
188-
FloatArray result = new FloatArray(totalSize);
189-
190-
int index = 0;
191-
while (copy.hasRemaining()) {
192-
result.set(index++, copy.get());
193-
}
194-
195-
return result;
196-
}
197-
198-
/**
199-
* Converts a FloatTensor to a TornadoVM FloatArray.
200-
*
201-
* @param input
202-
* FloatTensor to convert
203-
* @return FloatArray with the same data
204-
*/
205-
public FloatArray loadToFloatArray(FloatTensor input) {
206-
FloatArray floatArray = new FloatArray(input.size());
207-
208-
for (int i = 0; i < input.size(); i++) {
209-
floatArray.set(i, input.getFloat(i));
210-
}
211-
212-
return floatArray;
213-
}
214-
215173
}

0 commit comments

Comments
 (0)