@@ -30,6 +30,9 @@ limitations under the License.
3030using Tensorflow . Training . Saving . SavedModel ;
3131using Tensorflow . Util ;
3232using static Tensorflow . Binding ;
33+ using Tensorflow . Framework ;
34+ using Tensorflow . Sessions ;
35+
3336
3437namespace Tensorflow . Keras . Engine
3538{
@@ -134,6 +137,62 @@ public virtual List<IVariableV1> Weights
134137 }
135138 }
136139
140+ public virtual void set_weights ( IEnumerable < NDArray > weights )
141+ {
142+ if ( Weights . Count ( ) != weights . Count ( ) ) throw new ValueError (
143+ $ "You called `set_weights` on layer \" { this . name } \" " +
144+ $ "with a weight list of length { len ( weights ) } , but the layer was " +
145+ $ "expecting { len ( Weights ) } weights.") ;
146+
147+
148+
149+ // check if the shapes are compatible
150+ var weight_index = 0 ;
151+ foreach ( var w in weights )
152+ {
153+ if ( ! Weights [ weight_index ] . AsTensor ( ) . is_compatible_with ( w ) )
154+ {
155+ throw new ValueError ( $ "Layer weight shape { w . shape } not compatible with provided weight shape { Weights [ weight_index ] . shape } ") ;
156+ }
157+ weight_index ++ ;
158+ }
159+
160+ if ( tf . executing_eagerly ( ) )
161+ {
162+ foreach ( var ( this_w , v_w ) in zip ( Weights , weights ) )
163+ this_w . assign ( v_w , read_value : true ) ;
164+ }
165+ else
166+ {
167+ // TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.
168+
169+ //Tensors assign_ops = new Tensors();
170+ //var feed_dict = new FeedDict();
171+
172+ //Graph g = tf.Graph().as_default();
173+ //foreach (var (this_w, v_w) in zip(Weights, weights))
174+ //{
175+ // var tf_dtype = this_w.dtype;
176+ // var placeholder_shape = v_w.shape;
177+ // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
178+ // var assign_op = this_w.assign(assign_placeholder);
179+ // assign_ops.Add(assign_op);
180+ // feed_dict.Add(assign_placeholder, v_w);
181+ //}
182+ //var sess = tf.Session().as_default();
183+ //sess.run(assign_ops, feed_dict);
184+
185+ //g.Exit();
186+ }
187+ }
188+
189+ public List < NDArray > get_weights ( )
190+ {
191+ List < NDArray > weights = new List < NDArray > ( ) ;
192+ weights . AddRange ( Weights . ConvertAll ( x => x . numpy ( ) ) ) ;
193+ return weights ;
194+ }
195+
137196 protected int id ;
138197 public int Id => id ;
139198 protected string name ;
0 commit comments