@@ -18,11 +18,9 @@ limitations under the License.
1818using System . Collections . Generic ;
1919using System . Linq ;
2020using System . Threading ;
21- using Tensorflow . Contexts ;
2221using Tensorflow . Keras . ArgsDefinition ;
2322using Tensorflow . Keras . Layers ;
2423using Tensorflow . Keras . Utils ;
25- using Tensorflow . Operations . Activation ;
2624using Tensorflow . Train ;
2725using static Tensorflow . Binding ;
2826
@@ -34,7 +32,7 @@ namespace Tensorflow.Keras.Engine
3432 /// as convolution, batch norm, etc. These operations require managing weights,
3533 /// losses, updates, and inter-layer connectivity.
3634 /// </summary>
37- public abstract class Layer : AutoTrackable
35+ public abstract partial class Layer : AutoTrackable
3836 {
3937 /// <summary>
4038 /// Arguments initialize layer.
@@ -60,8 +58,19 @@ public abstract class Layer : AutoTrackable
6058 protected InputSpec inputSpec ;
6159 public bool SupportsMasking { get ; set ; }
6260 protected List < IVariableV1 > trainableWeights ;
63- public List < IVariableV1 > TrainableVariables => trainableWeights ;
61+ public List < IVariableV1 > trainable_variables
62+ {
63+ get
64+ {
65+ if ( trainableWeights . Count == 0 )
66+ _layers . ForEach ( x => trainableWeights . AddRange ( x . trainableWeights ) ) ;
67+
68+ return trainableWeights ;
69+ }
70+ }
71+
6472 protected List < IVariableV1 > nonTrainableWeights ;
73+ public List < IVariableV1 > non_trainable_variables => nonTrainableWeights ;
6574
6675 string name ;
6776 public string Name => name ;
@@ -112,20 +121,20 @@ public Layer(LayerArgs args)
112121 /// <param name="input"></param>
113122 /// <param name="is_training"></param>
114123 /// <returns></returns>
115- public Tensor [ ] Apply ( Tensor [ ] inputs , bool is_training = false )
124+ public Tensor Apply ( Tensor inputs , bool is_training = false )
116125 {
117- var input = inputs [ 0 ] ;
118- Tensor [ ] outputs = null ;
126+ Tensor outputs = null ;
119127
120128 callContext = callContext ?? new ThreadLocal < CallContext > ( )
121129 {
122130 Value = new CallContext ( )
123131 } ;
124132
133+ var eager = tf . executing_eagerly ( ) ;
125134 using var ctxManager = CallContext . enter ( ) ;
126135
127136 string nameScope = "" ;
128- if ( tf . executing_eagerly ( ) )
137+ if ( eager )
129138 {
130139 nameScope = name ;
131140 }
@@ -134,7 +143,7 @@ public Tensor[] Apply(Tensor[] inputs, bool is_training = false)
134143 throw new NotImplementedException ( "" ) ;
135144 }
136145
137- using var graph = tf . keras . backend . get_graph ( ) . as_default ( ) ;
146+ // using var graph = tf.keras.backend.get_graph().as_default();
138147
139148 tf_with ( ops . name_scope ( nameScope ) , scope =>
140149 {
@@ -143,82 +152,44 @@ public Tensor[] Apply(Tensor[] inputs, bool is_training = false)
143152
144153 outputs = call ( inputs , is_training : is_training ) ;
145154
146- ( input , outputs ) = _set_connectivity_metadata_ ( input , outputs ) ;
147- _handle_activity_regularization ( inputs [ 0 ] , outputs ) ;
148- _set_mask_metadata ( inputs [ 0 ] , outputs , null ) ;
155+ outputs = _set_connectivity_metadata_ ( inputs , outputs ) ;
156+ _handle_activity_regularization ( inputs , outputs ) ;
157+ _set_mask_metadata ( inputs , outputs , null ) ;
149158 } ) ;
150159
151160 return outputs ;
152161 }
153162
154- [ Obsolete ( "User Apply()" ) ]
155- public Tensor [ ] __call__ ( Tensor [ ] inputs ,
156- Tensor training = null ,
157- Tensor state = null ,
158- VariableScope scope = null )
163+ private Tensor _set_connectivity_metadata_ ( Tensor inputs , Tensor outputs )
159164 {
160- var input_list = inputs ;
161- var input = inputs [ 0 ] ;
162- Tensor [ ] outputs = null ;
163-
164- // We will attempt to build a TF graph if & only if all inputs are symbolic.
165- // This is always the case in graph mode. It can also be the case in eager
166- // mode when all inputs can be traced back to `keras.Input()` (when building
167- // models using the functional API).
168- bool build_graph = tf_utils . are_all_symbolic_tensors ( input_list ) ;
169-
170- if ( build_graph )
171- {
172- // Only create Keras history if at least one tensor originates from a
173- // `keras.Input`. Otherwise this Layer may be being used outside the Keras
174- // framework.
175- // base_layer_utils.create_keras_history(inputs)
176- }
177-
178- // with base_layer_utils.call_context(self):
179-
180- // Handle Keras mask propagation from previous layer to current layer.
181- // with base_layer_utils.call_context(self):
182- // Check input assumptions set after layer building, e.g. input shape.
183- if ( build_graph )
165+ /*var returnOutputs = new List<Tensor>();
166+ foreach(var x in outputs)
184167 {
185- // Symbolic execution on symbolic tensors. We will attempt to build
186- // the corresponding TF subgraph inside `backend.get_graph()`
187- var graph = tf . keras . backend . get_graph ( ) . as_default ( ) ;
188- tf_with ( ops . name_scope ( _name_scope ( ) ) , delegate
168+ if (inputs.Contains(x))
189169 {
190- // Build layer if applicable (if the `build` method has been
191- // overridden).
192- MaybeBuild ( inputs ) ;
193-
194- outputs = call ( inputs ,
195- // training: training,
196- state : state ) ;
197170
198- ( input , outputs ) = _set_connectivity_metadata_ ( input , outputs ) ;
199- _handle_activity_regularization ( inputs [ 0 ] , outputs ) ;
200- _set_mask_metadata ( inputs [ 0 ] , outputs , null ) ;
201- } ) ;
202- }
171+ }
172+ returnOutputs.Add(x);
173+ }*/
203174
204- return outputs ;
205- }
175+ new Node ( this , new NodeArgs
176+ {
177+ Outputs = outputs
178+ } ) ;
206179
207- private ( Tensor , Tensor [ ] ) _set_connectivity_metadata_ ( Tensor inputs , Tensor [ ] outputs )
208- {
209180 //_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
210- return ( inputs , outputs ) ;
181+ return outputs ;
211182 }
212183
213- private void _handle_activity_regularization ( Tensor inputs , Tensor [ ] outputs )
184+ private void _handle_activity_regularization ( Tensor inputs , Tensor outputs )
214185 {
215186 //if(_activity_regularizer != null)
216187 {
217188
218189 }
219190 }
220191
221- private void _set_mask_metadata ( Tensor inputs , Tensor [ ] outputs , Tensor previous_mask )
192+ private void _set_mask_metadata ( Tensor inputs , Tensor outputs , Tensor previous_mask )
222193 {
223194
224195 }
@@ -228,7 +199,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
228199 return null ;
229200 }
230201
231- protected virtual Tensor [ ] call ( Tensor [ ] inputs , bool is_training = false , Tensor state = null )
202+ protected virtual Tensor call ( Tensor inputs , bool is_training = false , Tensor state = null )
232203 {
233204 throw new NotImplementedException ( "" ) ;
234205 }
@@ -238,15 +209,15 @@ protected virtual string _name_scope()
238209 return Name ;
239210 }
240211
241- protected void MaybeBuild ( Tensor [ ] inputs )
212+ protected void MaybeBuild ( Tensor inputs )
242213 {
243214 // Check input assumptions set before layer building, e.g. input rank.
244215 if ( built )
245216 return ;
246217 if ( DType == TF_DataType . DtInvalid )
247- args . DType = inputs [ 0 ] . dtype ;
218+ args . DType = inputs . dtype ;
248219
249- var input_shapes = inputs [ 0 ] . TensorShape ;
220+ var input_shapes = inputs . TensorShape ;
250221 build ( input_shapes ) ;
251222 built = true ;
252223 }
0 commit comments