@@ -39,6 +39,12 @@ public class Layer : CheckpointableBase
3939 protected List < Operation > _updates ;
4040 public int [ ] _batch_input_shape ;
4141
42+ private List < Node > _inbound_nodes ;
43+ public List < Node > inbound_nodes => _inbound_nodes ;
44+
45+ private List < Node > _outbound_nodes ;
46+ public List < Node > outbound_nodes => _outbound_nodes ;
47+
4248 public Layer ( bool trainable = true ,
4349 string name = null ,
4450 TF_DataType dtype = TF_DataType . DtInvalid ,
@@ -59,13 +65,15 @@ public Layer(bool trainable = true,
5965 _batch_input_shape = new int [ ] { - 1 , - 1 } ;
6066
6167 _dtype = dtype ;
68+
69+ _inbound_nodes = new List < Node > ( ) ;
6270 }
6371
64- public Tensor __call__ ( Tensor inputs ,
72+ public Tensor __call__ ( Tensor [ ] inputs ,
6573 Tensor training = null ,
6674 VariableScope scope = null )
6775 {
68- var input_list = new Tensor [ ] { inputs } ;
76+ var input_list = inputs ;
6977 Tensor outputs = null ;
7078
7179 // We will attempt to build a TF graph if & only if all inputs are symbolic.
@@ -88,9 +96,9 @@ public Tensor __call__(Tensor inputs,
8896 // Symbolic execution on symbolic tensors. We will attempt to build
8997 // the corresponding TF subgraph inside `backend.get_graph()`
9098 var graph = backend . get_graph ( ) ;
91- outputs = call ( inputs , training : training ) ;
92- _handle_activity_regularization ( inputs , outputs ) ;
93- _set_mask_metadata ( inputs , outputs , null ) ;
99+ outputs = call ( inputs [ 0 ] , training : training ) ;
100+ _handle_activity_regularization ( inputs [ 0 ] , outputs ) ;
101+ _set_mask_metadata ( inputs [ 0 ] , outputs , null ) ;
94102 }
95103 } ) ;
96104
@@ -125,10 +133,10 @@ protected virtual string _name_scope()
125133 return null ;
126134 }
127135
128- protected void _maybe_build ( Tensor inputs )
136+ protected void _maybe_build ( Tensor [ ] inputs )
129137 {
130- var input_list = new Tensor [ ] { inputs } ;
131- build ( inputs . getShape ( ) ) ;
138+ var input_list = inputs ;
139+ build ( input_list [ 0 ] . getShape ( ) ) ;
132140 }
133141
134142 protected virtual void build ( TensorShape input_shape )
@@ -143,10 +151,16 @@ protected virtual RefVariable add_weight(string name,
143151 bool ? trainable = null ,
144152 Func < string , int [ ] , TF_DataType , IInitializer , bool , RefVariable > getter = null )
145153 {
154+ if ( dtype == TF_DataType . DtInvalid )
155+ dtype = TF_DataType . TF_FLOAT ;
156+
157+ if ( trainable == null )
158+ trainable = true ;
159+
146160 var variable = _add_variable_with_custom_getter ( name ,
147161 shape ,
148162 dtype : dtype ,
149- getter : getter ,
163+ getter : getter == null ? base_layer_utils . make_variable : getter ,
150164 overwrite : true ,
151165 initializer : initializer ,
152166 trainable : trainable . Value ) ;
0 commit comments