
Commit 9ff09c4

Skip layers without trainable weights in save_weights.
1 parent 35e070d

4 files changed (+47 −51 lines)

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 12 additions & 12 deletions
@@ -318,16 +318,16 @@ public static Tensor[] fused_batch_norm(Tensor x,
             return _op.outputs;
         }

-        public static Tensor[] fused_batch_norm_v3(Tensor x,
-            Tensor scale,
-            Tensor offset,
-            IVariableV1 mean,
-            IVariableV1 variance,
-            float epsilon = 0.0001f,
-            float exponential_avg_factor = 1.0f,
-            string data_format = "NHWC",
-            bool is_training = true,
-            string name = null)
+        public static Tensors fused_batch_norm_v3(Tensor x,
+            IVariableV1 scale,
+            IVariableV1 offset,
+            IVariableV1 mean,
+            IVariableV1 variance,
+            float epsilon = 0.0001f,
+            float exponential_avg_factor = 1.0f,
+            string data_format = "NHWC",
+            bool is_training = true,
+            string name = null)
         {
             if (tf.executing_eagerly())
             {
@@ -337,8 +337,8 @@ public static Tensor[] fused_batch_norm_v3(Tensor x,
                     x,
                     scale,
                     offset,
-                    mean.AsTensor(),
-                    variance.AsTensor(),
+                    mean,
+                    variance,
                     "epsilon", epsilon,
                     "exponential_avg_factor", exponential_avg_factor,
                     "data_format", data_format,

src/TensorFlowNET.Core/Operations/nn_impl.py.cs

Lines changed: 3 additions & 6 deletions
@@ -107,9 +107,6 @@ public static Tensor[] fused_batch_norm(Tensor x,
             string name = null,
             float exponential_avg_factor = 1.0f)
         {
-            x = ops.convert_to_tensor(x, name: "input");
-            var scale_tensor = ops.convert_to_tensor(scale, name: "scale");
-            var offset_tensor = ops.convert_to_tensor(offset, name: "offset");
             /*if (mean == null)
                 mean = constant_op.constant(new float[0]);
             if (variance == null)
@@ -118,11 +115,11 @@ public static Tensor[] fused_batch_norm(Tensor x,
             epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;

             var results = gen_nn_ops.fused_batch_norm_v3(x,
-                scale_tensor,
-                offset_tensor,
+                scale,
+                offset,
                 mean,
                 variance,
-                epsilon,
+                epsilon: epsilon,
                 exponential_avg_factor: exponential_avg_factor,
                 data_format: data_format,
                 is_training: is_training,
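Two details worth noting here. Passing epsilon by name (epsilon: epsilon) keeps the call correct even as the parameter list around it changes, and the unchanged context line above clamps epsilon to a floor before the op is built, because the fused kernel rejects values that are too small. A sketch of that clamp, assuming the upstream TensorFlow floor of 1.001e-5 (the constant's definition sits outside this hunk, so the value is an assumption):

// Sketch of the clamp visible in the unchanged context line above.
// min_epsilon mirrors upstream TensorFlow's floor; assumed here, since
// its definition is not part of this diff.
const float min_epsilon = 1.001e-5f;
float epsilon = 1e-7f;                                   // too small for the kernel
epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; // -> 1.001e-5f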

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 7 additions & 4 deletions
@@ -163,11 +163,14 @@ void variable_accessed(BaseResourceVariable variable)
         /// </summary>
         /// <returns></returns>
         protected Tensor read_value()
-            => tf_with(ops.name_scope("Read"), delegate
-            {
-                var value = _read_variable_op();
-                return array_ops.identity(value);
+        {
+            var value = tf_with(ops.name_scope("Read"), delegate
+            {
+                return _read_variable_op();
             });
+            return array_ops.identity(value);
+        }
+

         public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
         {
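The rewrite changes which ops land inside the "Read" name scope: only the raw _read_variable_op() is created under it, while the identity snapshot is now taken in the enclosing scope, matching how Python TensorFlow structures resource-variable reads. An annotated sketch of the new shape, with illustrative op names for a variable called dense/kernel:

// Before: both ops were created under the "Read" scope, roughly
//   dense/kernel/Read/ReadVariableOp
//   dense/kernel/Read/Identity
// After: the identity escapes the scope, roughly
//   dense/kernel/Read/ReadVariableOp
//   Identity          (created in whatever scope the caller is in)
protected Tensor read_value()
{
    var value = tf_with(ops.name_scope("Read"), delegate
    {
        return _read_variable_op();    // ReadVariableOp under .../Read
    });
    return array_ops.identity(value);  // snapshot taken outside the scope
}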

src/TensorFlowNET.Keras/Saving/hdf5_format.cs

Lines changed: 25 additions & 29 deletions
@@ -101,26 +101,28 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
                 if (success)
                     original_backend = attr.First();
             }
-            List<ILayer> filtered_layers = new List<ILayer>();
-            List<IVariableV1> weights;
+
+            var filtered_layers = new List<ILayer>();
             foreach (var layer in layers)
             {
-                weights = _legacy_weights(layer);
+                var weights = _legacy_weights(layer);
                 if (weights.Count > 0)
-                {
                     filtered_layers.append(layer);
-                }
             }
+
             string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
             var filtered_layer_names = new List<string>();
             foreach(var name in layer_names)
             {
+                if (!filtered_layers.Select(x => x.Name).Contains(name))
+                    continue;
                 long g = H5G.open(f, name);
                 var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                 if (weight_names.Count() > 0)
                     filtered_layer_names.Add(name);
                 H5G.close(g);
             }
+
             layer_names = filtered_layer_names.ToArray();
             if (layer_names.Length != filtered_layers.Count())
                 throw new ValueError("You are trying to load a weight file " +
@@ -133,7 +135,6 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
                 var weight_values = new List<NDArray>();
                 long g = H5G.open(f, name);
                 var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
-                var get_Name = "";
                 foreach (var i_ in weight_names)
                 {
                     (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
@@ -153,6 +154,7 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
                         $"{weight_values.Count()} elements.");
                 weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
             }
+
             keras.backend.batch_set_value(weight_value_tuples);
         }
         public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
@@ -175,43 +177,37 @@ public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
             Hdf5.WriteAttribute(f, "keras_version", "2.5.0");

             long g = 0, crDataGroup=0;
-            List<IVariableV1> weights = new List<IVariableV1>();
-            //List<IVariableV1> weight_values = new List<IVariableV1>();
-            List<string> weight_names = new List<string>();
-            foreach (var layer in layers) {
-                weight_names = new List<string>();
-                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
-                weights = _legacy_weights(layer);
-                //weight_values= keras.backend.batch_get_value(weights);
+            foreach (var layer in layers)
+            {
+                var weights = _legacy_weights(layer);
+                if (weights.Count == 0)
+                    continue;
+
+                var weight_names = new List<string>();
+                // weight_values= keras.backend.batch_get_value(weights);
                 foreach (var weight in weights)
-                {
                     weight_names.Add(weight.Name);
-                }
+
+                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
                 save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
-                Tensor tensor = null;
-                foreach (var (name, val) in zip(weight_names, weights)) {
-
-                    tensor = val.AsTensor();
+                foreach (var (name, val) in zip(weight_names, weights))
+                {
+                    var tensor = val.AsTensor();
                     if (name.IndexOf("/") > 1)
                     {
                         crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
                         WriteDataset(crDataGroup, name.Split('/')[1], tensor);
                         Hdf5.CloseGroup(crDataGroup);
                     }
-                    else {
+                    else
+                    {
                         WriteDataset(crDataGroup, name, tensor);
                     }
-
-                    tensor = null;
-                }
+                }
                 Hdf5.CloseGroup(g);
-                weight_names = null;
             }
-            weights = null;
-            // weight_values = null;
-
-
         }
+
         private static void save_attributes_to_hdf5_group(long f,string name ,Array data)
         {
             int num_chunks = 1;
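The net effect of this file's changes: weightless layers (Flatten, Dropout, pooling, and the like) are skipped on both the save and load paths, so the layer_names attribute written to the HDF5 file stays in sync with the loader's filtered_layers list instead of tripping the ValueError shown above. A hedged sketch of the guard at the core of the commit (a paraphrase of the diff, not a verbatim excerpt):

// For any list of layers, only those that actually own weights get an
// HDF5 group; _legacy_weights returns an empty list for the rest.
foreach (var layer in layers)
{
    var weights = _legacy_weights(layer);   // empty for Flatten/Dropout-style layers
    if (weights.Count == 0)
        continue;                           // no group, no weight_names attribute

    // ... CreateOrOpenGroup, save_attributes_to_hdf5_group, WriteDataset ...
}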
