Skip to content

Commit 5d45815

Browse files
committed
Refactor Weights class to support TornadoVM and standard modes.
Restructured the `Weights` class by introducing separate constructors for TornadoVM and non-TornadoVM modes. Simplified initialization logic and improved readability by clearly separating the handling of standard and TornadoVM-specific formats.
1 parent ac7a0ae commit 5d45815

File tree

2 files changed

+49
-61
lines changed

2 files changed

+49
-61
lines changed

src/main/java/com/example/loader/weights/Weights.java

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -80,78 +80,66 @@ public class Weights {
8080
* RoPE sine components
8181
* @param wcls
8282
* Classifier weights for output logits
83+
*
84+
/**
85+
* Constructor for standard (non-TornadoVM) mode
8386
*/
84-
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight,
85-
FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
87+
public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight,
88+
FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo,
89+
FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3,
90+
FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag,
91+
FloatTensor wcls, GGMLType weightType) {
92+
// Standard format
8693
this.token_embedding_table = token_embedding_table;
8794
this.rms_att_weight = rms_att_weight;
88-
this.wq = wq;
89-
this.wk = wk;
90-
this.wv = wv;
91-
this.wo = wo;
95+
this.wq = wq; this.wk = wk; this.wv = wv; this.wo = wo;
9296
this.rms_ffn_weight = rms_ffn_weight;
93-
this.w1 = w1;
94-
this.w2 = w2;
95-
this.w3 = w3;
97+
this.w1 = w1; this.w2 = w2; this.w3 = w3;
98+
this.wcls = wcls;
9699
this.rms_final_weight = rms_final_weight;
97100
this.freq_cis_real = freq_cis_real;
98101
this.freq_cis_imag = freq_cis_imag;
99-
this.wcls = wcls;
100-
this.tokenEmbeddingTable = loadToFloatArray(token_embedding_table); // (vocab_size, dim)
101-
102-
if (LlamaApp.USE_TORNADOVM) {
103-
this.freq_cis_imagFlat = loadToSingleFloatArray(freq_cis_imag);
104-
this.freq_cis_realFlat = loadToSingleFloatArray(freq_cis_real);
105-
106-
// Store read-only weight as a ByteArray in TornadoVM
107-
this.wclsByteArray = ByteArray.fromSegment(wcls.asMemorySegment());
108-
this.rms_final_weight_as_floatArray = FloatArray.fromFloatBuffer(rms_final_weight);
109-
110-
this.rms_att_weightLayered = loadToFloatArray(rms_att_weight);
111-
112-
this.wqLayered = loadToFloatArray(wq);
113-
this.wkLayered = loadToFloatArray(wk);
114-
this.wvLayered = loadToFloatArray(wv);
115-
this.woLayered = loadToFloatArray(wo);
116-
this.rms_ffn_weightLayered = loadToFloatArray(rms_ffn_weight);
117-
this.w1Layered = loadToFloatArray(w1);
118-
this.w2Layered = loadToFloatArray(w2);
119-
this.w3Layered = loadToFloatArray(w3);
120-
121-
} else {
122-
this.freq_cis_imagFlat = null;
123-
this.freq_cis_realFlat = null;
124-
this.wclsByteArray = null;
125-
this.rms_final_weight_as_floatArray = null;
126-
this.rms_att_weightLayered = null;
127-
this.wqLayered = null;
128-
this.wkLayered = null;
129-
this.wvLayered = null;
130-
this.woLayered = null;
131-
this.rms_ffn_weightLayered = null;
132-
this.w1Layered = null;
133-
this.w2Layered = null;
134-
this.w3Layered = null;
135-
}
136102
this.weightType = weightType;
103+
104+
// TornadoVM format (null when not using TornadoVM)
105+
this.tokenEmbeddingTable = null;
106+
this.rms_att_weightLayered = null;
107+
this.wqLayered = null; this.wkLayered = null; this.wvLayered = null; this.woLayered = null;
108+
this.rms_ffn_weightLayered = null;
109+
this.w1Layered = null; this.w2Layered = null; this.w3Layered = null;
110+
this.rms_final_weight_as_floatArray = null;
111+
this.freq_cis_realFlat = null; this.freq_cis_imagFlat = null;
112+
this.wclsByteArray = null;
137113
}
138114

139115
/**
140-
* Converts an array of FloatTensor objects to TornadoVM FloatArray format. This enables efficient GPU computation by flattening multi-dimensional tensors.
141-
*
142-
* @param array
143-
* Array of FloatTensors to convert
144-
* @return Array of FloatArrays with the same data
116+
* Constructor for TornadoVM mode
145117
*/
146-
private static FloatArray[] loadToFloatArray(FloatTensor[] array) {
147-
FloatArray[] floatArrays = new FloatArray[array.length];
148-
for (int i = 0; i < array.length; i++) {
149-
floatArrays[i] = new FloatArray(array[i].size());
150-
for (int j = 0; j < array[i].size(); j++) {
151-
floatArrays[i].set(j, array[i].getFloat(j));
152-
}
153-
}
154-
return floatArrays;
118+
public Weights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered,
119+
FloatArray[] wqLayered, FloatArray[] wkLayered, FloatArray[] wvLayered, FloatArray[] woLayered,
120+
FloatArray[] rms_ffn_weightLayered, FloatArray[] w1Layered, FloatArray[] w2Layered, FloatArray[] w3Layered,
121+
FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat,
122+
ByteArray wclsByteArray, GGMLType weightType) {
123+
// Standard format (null when using TornadoVM)
124+
this.token_embedding_table = null;
125+
this.rms_att_weight = null;
126+
this.wq = null; this.wk = null; this.wv = null; this.wo = null;
127+
this.rms_ffn_weight = null;
128+
this.w1 = null; this.w2 = null; this.w3 = null;
129+
this.wcls = null;
130+
this.rms_final_weight = null;
131+
this.freq_cis_real = null; this.freq_cis_imag = null;
132+
133+
// TornadoVM format
134+
this.tokenEmbeddingTable = tokenEmbeddingTable;
135+
this.rms_att_weightLayered = rms_att_weightLayered;
136+
this.wqLayered = wqLayered; this.wkLayered = wkLayered; this.wvLayered = wvLayered; this.woLayered = woLayered;
137+
this.rms_ffn_weightLayered = rms_ffn_weightLayered;
138+
this.w1Layered = w1Layered; this.w2Layered = w2Layered; this.w3Layered = w3Layered;
139+
this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray;
140+
this.freq_cis_realFlat = freq_cis_realFlat; this.freq_cis_imagFlat = freq_cis_imagFlat;
141+
this.wclsByteArray = wclsByteArray;
142+
this.weightType = weightType;
155143
}
156144

157145
/**

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import java.util.List;
1313

1414
public class TornadoVMMasterPlan {
15-
private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "false"));
15+
private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
1616

1717
private final State state;
1818
private final Configuration config;

0 commit comments

Comments
 (0)