@@ -30,6 +30,9 @@ limitations under the License.
3030using Tensorflow . Training . Saving . SavedModel ;
3131using Tensorflow . Util ;
3232using static Tensorflow . Binding ;
33+ using Tensorflow . Framework ;
34+ using Tensorflow . Sessions ;
35+
3336
3437namespace Tensorflow . Keras . Engine
3538{
@@ -134,21 +137,53 @@ public virtual List<IVariableV1> Weights
134137 }
135138 }
136139
137- public virtual void set_weights ( List < NDArray > weights )
140+ public virtual void set_weights ( IEnumerable < NDArray > weights )
138141 {
139142 if ( Weights . Count ( ) != weights . Count ( ) ) throw new ValueError (
140143 $ "You called `set_weights` on layer \" { this . name } \" " +
141144 $ "with a weight list of length { len ( weights ) } , but the layer was " +
142145 $ "expecting { len ( Weights ) } weights.") ;
143- for ( int i = 0 ; i < weights . Count ( ) ; i ++ )
146+
147+
148+
149+ // check if the shapes are compatible
150+ var weight_index = 0 ;
151+ foreach ( var w in weights )
144152 {
145- if ( weights [ i ] . shape != Weights [ i ] . shape )
153+ if ( ! Weights [ weight_index ] . AsTensor ( ) . is_compatible_with ( w ) )
146154 {
147- throw new ValueError ( $ "Layer weight shape { weights [ i ] . shape } not compatible with provided weight shape { Weights [ i ] . shape } ") ;
155+ throw new ValueError ( $ "Layer weight shape { w . shape } not compatible with provided weight shape { Weights [ weight_index ] . shape } ") ;
148156 }
157+ weight_index ++ ;
158+ }
159+
160+ if ( tf . executing_eagerly ( ) )
161+ {
162+ foreach ( var ( this_w , v_w ) in zip ( Weights , weights ) )
163+ this_w . assign ( v_w , read_value : true ) ;
164+ }
165+ else
166+ {
167+ // TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.
168+
169+ //Tensors assign_ops = new Tensors();
170+ //var feed_dict = new FeedDict();
171+
172+ //Graph g = tf.Graph().as_default();
173+ //foreach (var (this_w, v_w) in zip(Weights, weights))
174+ //{
175+ // var tf_dtype = this_w.dtype;
176+ // var placeholder_shape = v_w.shape;
177+ // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
178+ // var assign_op = this_w.assign(assign_placeholder);
179+ // assign_ops.Add(assign_op);
180+ // feed_dict.Add(assign_placeholder, v_w);
181+ //}
182+ //var sess = tf.Session().as_default();
183+ //sess.run(assign_ops, feed_dict);
184+
185+ //g.Exit();
149186 }
150- foreach ( var ( this_w , v_w ) in zip ( Weights , weights ) )
151- this_w . assign ( v_w , read_value : true ) ;
152187 }
153188
154189 public List < NDArray > get_weights ( )
0 commit comments