
Commit 2f3a17d

dataangel authored and Oceania2018 committed
update: keras.save_weights or keras.load_weights
1 parent 3bfdedc commit 2f3a17d

3 files changed: +181 −11 lines

src/TensorFlowNET.Keras/Engine/Model.Training.cs

Lines changed: 8 additions & 2 deletions
@@ -27,9 +27,15 @@ public void load_weights(string filepath, bool by_name = false, bool skip_mismat
             }
             else
             {
-                fdf5_format.load_weights_from_hdf5_group(fileId, Layers);
+                hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
             }
-            H5G.close(fileId);
+            Hdf5.CloseFile(fileId);
+        }
+        public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null)
+        {
+            long fileId = Hdf5.CreateFile(filepath);
+            hdf5_format.save_weights_to_hdf5_group(fileId, Layers);
+            Hdf5.CloseFile(fileId);
         }
     }
 }
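
Note: the new save_weights/load_weights pair round-trips a model's layer weights through a single HDF5 file. A minimal usage sketch, where BuildModel, x_train, and y_train are hypothetical stand-ins for any compiled keras model and its training data:

    // Hypothetical round-trip: persist trained weights, then restore them
    // into a freshly built model with the same architecture.
    var model = BuildModel();             // assumed helper returning a compiled keras model
    model.fit(x_train, y_train);          // train so the weights are non-trivial
    model.save_weights("weights.h5");     // writes each layer's variables into one HDF5 file

    var restored = BuildModel();          // same architecture, fresh random weights
    restored.load_weights("weights.h5");  // replaces them with the saved values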

src/TensorFlowNET.Keras/Losses/LogCosh.cs

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ public class LogCosh : LossFunctionWrapper, ILossFunc
         public LogCosh(
             string reduction = null,
             string name = null) :
-            base(reduction: reduction, name: name == null ? "huber" : name){ }
+            base(reduction: reduction, name: name == null ? "log_cosh" : name){ }
 
         public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
         {
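
Note: this is a one-line copy-paste fix: LogCosh previously fell back to Huber's name. A quick sketch of the new default (assuming the base wrapper exposes the configured name the way the other losses here do):

    var logcosh = new LogCosh();               // default name is now "log_cosh", not "huber"
    var named = new LogCosh(name: "my_loss");  // an explicit name still takes precedence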

src/TensorFlowNET.Keras/Saving/fdf5_format.cs renamed to src/TensorFlowNET.Keras/Saving/hdf5_format.cs

Lines changed: 172 additions & 8 deletions
@@ -8,12 +8,12 @@
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
 using System.Linq;
-
+using Tensorflow.Util;
 namespace Tensorflow.Keras.Saving
 {
-    public class fdf5_format
+    public class hdf5_format
     {
-
+        private static int HDF5_OBJECT_HEADER_LIMIT = 64512;
         public static void load_model_from_hdf5(string filepath = "", Dictionary<string, object> custom_objects = null, bool compile = false)
         {
             long root = Hdf5.OpenFile(filepath,true);
@@ -79,10 +79,7 @@ public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Di
         {
 
         }
-        public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
-        {
 
-        }
         public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
         {
             string original_keras_version = "2.4.0";
@@ -136,9 +133,14 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
             var weight_values = new List<NDArray>();
             long g = H5G.open(f, name);
             var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+            var get_Name = "";
             foreach (var i_ in weight_names)
             {
-                (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
+                get_Name = i_;
+                if (get_Name.IndexOf("/") > 1) {
+                    get_Name = get_Name.Split('/')[1];
+                }
+                (bool success, Array result) = Hdf5.ReadDataset<float>(g, get_Name);
                 if (success)
                     weight_values.Add(np.array(result));
             }
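
Note: this loader change mirrors how the new writer keys its datasets. The weight_names attribute stores full variable paths, but each dataset is written under the middle path segment, so the reader strips the layer prefix before reading. A sketch with an assumed name:

    // Hypothetical entry from the "weight_names" attribute:
    string weightName = "dense_1/kernel:0";
    if (weightName.IndexOf("/") > 1)
        weightName = weightName.Split('/')[1];  // -> "kernel:0", the dataset key used on disk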
@@ -165,9 +167,171 @@ public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dict
         {
 
         }
-        public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
+        public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
+        {
+            List<string> layerName=new List<string>();
+            foreach (var layer in layers)
+            {
+                layerName.Add(layer.Name);
+            }
+            save_attributes_to_hdf5_group(f, "layer_names", layerName.ToArray());
+            Hdf5.WriteAttribute(f, "backend", "tensorflow");
+            Hdf5.WriteAttribute(f, "keras_version", "2.5.0");
+
+            long g = 0, crDataGroup=0;
+            List<IVariableV1> weights = new List<IVariableV1>();
+            //List<IVariableV1> weight_values = new List<IVariableV1>();
+            List<string> weight_names = new List<string>();
+            foreach (var layer in layers) {
+                weight_names = new List<string>();
+                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
+                weights = _legacy_weights(layer);
+                //weight_values= keras.backend.batch_get_value(weights);
+                foreach (var weight in weights)
+                {
+                    weight_names.Add(weight.Name);
+                }
+                save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
+                Tensor tensor = null;
+                string get_Name = "";
+                foreach (var (name, val) in zip(weight_names, weights)) {
+                    get_Name = name;
+                    tensor = val.AsTensor();
+                    if (get_Name.IndexOf("/") > 1)
+                    {
+                        get_Name = name.Split('/')[1];
+                        crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(get_Name));
+                        Hdf5.CloseGroup(crDataGroup);
+                    }
+                    WriteDataset(g, get_Name, tensor);
+                    tensor = null;
+                }
+                Hdf5.CloseGroup(g);
+                weight_names = null;
+            }
+            weights = null;
+            // weight_values = null;
+
+
+        }
+        private static void save_attributes_to_hdf5_group(long f,string name ,Array data)
+        {
+            int num_chunks = 1;
+
+            var chunked_data = Split(data, num_chunks);
+            int getSize= 0;
+
+            string getType = data.Length>0?data.GetValue(0).GetType().Name.ToLower():"string";
+
+            switch (getType)
+            {
+                case "single":
+                    getSize=sizeof(float);
+                    break;
+                case "double":
+                    getSize = sizeof(double);
+                    break;
+                case "string":
+                    getSize = -1;
+                    break;
+                case "int32":
+                    getSize = sizeof(int);
+                    break;
+                case "int64":
+                    getSize = sizeof(long);
+                    break;
+                default:
+                    getSize=-1;
+                    break;
+            }
+            int getCount = chunked_data.Count;
+
+            if (getSize != -1) {
+                num_chunks = (int)Math.Ceiling((double)(getCount * getSize) / (double)HDF5_OBJECT_HEADER_LIMIT);
+                if (num_chunks > 1) chunked_data = Split(data, num_chunks);
+            }
+
+            if (num_chunks > 1)
+            {
+                foreach (var (chunk_id, chunk_data) in enumerate(chunked_data))
+                {
+
+                    WriteAttrs(f, getType, $"{name}{chunk_id}", chunk_data.ToArray());
+
+                }
+
+            }
+            else {
+
+                WriteAttrs(f, getType,name, data);
+
+            }
+
+        }
+        private static void WriteDataset(long f, string name, Tensor data)
+        {
+            switch (data.dtype)
+            {
+                case TF_DataType.TF_FLOAT:
+                    Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+                case TF_DataType.TF_DOUBLE:
+                    Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+                case TF_DataType.TF_INT32:
+                    Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+                case TF_DataType.TF_INT64:
+                    Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+                default:
+                    Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+            }
+        }
+        private static void WriteAttrs(long f,string typename, string name, Array data)
         {
+            switch (typename)
+            {
+                case "single":
+                    Hdf5.WriteAttributes<float>(f, name, data);
+                    break;
+                case "double":
+                    Hdf5.WriteAttributes<double>(f, name, data);
+                    break;
+                case "string":
+                    Hdf5.WriteAttributes<string>(f, name, data);
+                    break;
+                case "int32":
+                    Hdf5.WriteAttributes<int>(f, name, data);
+                    break;
+                case "int64":
+                    Hdf5.WriteAttributes<long>(f, name, data);
+                    break;
+                default:
+                    Hdf5.WriteAttributes<string>(f, name,data);
+                    break;
+            }
+        }
+        private static List<List<object>> Split(Array list, int chunkSize)
+        {
+            var splitList = new List<List<object>>();
+            var chunkCount = (int)Math.Ceiling((double)list.Length / (double)chunkSize);
+
+            for (int c = 0; c < chunkCount; c++)
+            {
+                var skip = c * chunkSize;
+                var take = skip + chunkSize;
+                var chunk = new List<object>(chunkSize);
+
+                for (int e = skip; e < take && e < list.Length; e++)
+                {
+                    chunk.Add(list.GetValue(e));
+                }
+                splitList.Add(chunk);
+            }
 
+            return splitList;
         }
         public static string[] load_attributes_from_hdf5_group(long group, string name)
         {
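
Note: save_attributes_to_hdf5_group exists because HDF5 limits how much attribute data fits in one object header; the 64512-byte constant matches the limit Keras uses in Python, and larger payloads are split into numbered chunks (name0, name1, ...). A worked sketch of the chunk arithmetic, with an assumed payload size:

    // Assumed payload: 20,000 int64 values = 160,000 bytes of attribute data.
    const int HDF5_OBJECT_HEADER_LIMIT = 64512;
    int count = 20000;
    int elemSize = sizeof(long);  // 8 bytes per value
    int numChunks = (int)Math.Ceiling((double)(count * elemSize) / HDF5_OBJECT_HEADER_LIMIT);
    // numChunks == 3, so the values would be written as name0, name1, name2.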
