Skip to content

Commit 426a55c

Browse files
Add set_weights and get_weights APIs
1 parent f4e7fd4 commit 426a55c

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1919
List<IVariableV1> TrainableWeights { get; }
2020
List<IVariableV1> NonTrainableWeights { get; }
2121
List<IVariableV1> Weights { get; set; }
22-
void set_weights(List<NDArray> weights);
22+
void set_weights(IEnumerable<NDArray> weights);
2323
List<NDArray> get_weights();
2424
Shape OutputShape { get; }
2525
Shape BatchInputShape { get; }

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
7575
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
7676

7777
public List<NDArray> get_weights() => throw new NotImplementedException();
78-
public void set_weights(List<NDArray> weights) => throw new NotImplementedException();
78+
public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException();
7979
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();
8080

8181
public Shape OutputShape => throw new NotImplementedException();

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ limitations under the License.
3030
using Tensorflow.Training.Saving.SavedModel;
3131
using Tensorflow.Util;
3232
using static Tensorflow.Binding;
33+
using Tensorflow.Framework;
34+
using Tensorflow.Sessions;
35+
3336

3437
namespace Tensorflow.Keras.Engine
3538
{
@@ -134,21 +137,53 @@ public virtual List<IVariableV1> Weights
134137
}
135138
}
136139

137-
public virtual void set_weights(List<NDArray> weights)
140+
public virtual void set_weights(IEnumerable<NDArray> weights)
138141
{
139142
if (Weights.Count() != weights.Count()) throw new ValueError(
140143
$"You called `set_weights` on layer \"{this.name}\"" +
141144
$"with a weight list of length {len(weights)}, but the layer was " +
142145
$"expecting {len(Weights)} weights.");
143-
for (int i = 0; i < weights.Count(); i++)
146+
147+
148+
149+
// check if the shapes are compatible
150+
var weight_index = 0;
151+
foreach(var w in weights)
144152
{
145-
if (weights[i].shape != Weights[i].shape)
153+
if (!Weights[weight_index].AsTensor().is_compatible_with(w))
146154
{
147-
throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}");
155+
throw new ValueError($"Layer weight shape {w.shape} not compatible with provided weight shape {Weights[weight_index].shape}");
148156
}
157+
weight_index++;
158+
}
159+
160+
if (tf.executing_eagerly())
161+
{
162+
foreach (var (this_w, v_w) in zip(Weights, weights))
163+
this_w.assign(v_w, read_value: true);
164+
}
165+
else
166+
{
167+
// TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.
168+
169+
//Tensors assign_ops = new Tensors();
170+
//var feed_dict = new FeedDict();
171+
172+
//Graph g = tf.Graph().as_default();
173+
//foreach (var (this_w, v_w) in zip(Weights, weights))
174+
//{
175+
// var tf_dtype = this_w.dtype;
176+
// var placeholder_shape = v_w.shape;
177+
// var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
178+
// var assign_op = this_w.assign(assign_placeholder);
179+
// assign_ops.Add(assign_op);
180+
// feed_dict.Add(assign_placeholder, v_w);
181+
//}
182+
//var sess = tf.Session().as_default();
183+
//sess.run(assign_ops, feed_dict);
184+
185+
//g.Exit();
149186
}
150-
foreach (var (this_w, v_w) in zip(Weights, weights))
151-
this_w.assign(v_w, read_value: true);
152187
}
153188

154189
public List<NDArray> get_weights()

0 commit comments

Comments
 (0)