@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414 limitations under the License.
1515******************************************************************************/
1616
17+ using System ;
1718using System . Linq ;
1819using System . Collections . Generic ;
1920using Tensorflow . Keras . ArgsDefinition ;
@@ -35,8 +36,9 @@ public class Sequential : Functional
3536 bool _auto_track_sub_layers ;
3637 Shape _inferred_input_shape ;
3738 bool _has_explicit_input_shape ;
38-
39+ bool _graph_initialized ;
3940 public Shape output_shape => outputs [ 0 ] . shape ;
41+ List < INode > _created_nodes ;
4042
4143 public Sequential ( SequentialArgs args )
4244 : base ( args . Inputs , args . Outputs , name : args . Name )
@@ -49,12 +51,13 @@ public Sequential(SequentialArgs args)
4951 _auto_track_sub_layers = false ;
5052 _has_explicit_input_shape = false ;
5153 _is_graph_network = false ;
54+ _created_nodes = new List < INode > ( ) ;
5255
5356 // Add to the model any layers passed to the constructor.
5457 if ( args . Layers != null )
5558 {
5659 foreach ( var layer in args . Layers )
57- add ( layer as Layer ) ;
60+ add ( layer ) ;
5861 }
5962 }
6063
@@ -118,7 +121,69 @@ public void add(ILayer layer)
118121 }
119122 else
120123 {
124+ _self_tracked_trackables . add ( layer ) ;
125+ _handle_deferred_layer_dependencies ( layer ) ;
126+ }
127+ }
121128
129+ void _handle_deferred_layer_dependencies ( params ILayer [ ] layers )
130+ {
131+ _layers . AddRange ( layers ) ;
132+ }
133+
134+ protected override Tensors Call ( Tensors inputs , Tensor state = null , bool ? training = null )
135+ {
136+ if ( ! _has_explicit_input_shape )
137+ {
138+ _build_graph_network_for_inferred_shape ( inputs . shape , inputs . dtype ) ;
139+ }
140+
141+ if ( _graph_initialized )
142+ {
143+ if ( ! built )
144+ _init_graph_network ( this . inputs , outputs ) ;
145+ return base . Call ( inputs , state , training ) ;
146+ }
147+
148+ return base . Call ( inputs , state , training ) ;
149+ }
150+
151+ void _build_graph_network_for_inferred_shape ( Shape input_shape , TF_DataType input_dtype )
152+ {
153+ ops . init_scope ( ) ;
154+ var inputs = keras . Input ( batch_input_shape : input_shape ,
155+ dtype : input_dtype ,
156+ name : $ "{ _layers [ 0 ] . Name } _input") ;
157+ Tensors layer_input = inputs ;
158+ Tensors layer_output = null ;
159+ Tensors outputs = null ;
160+
161+ foreach ( var layer in _layers )
162+ {
163+ clear_previously_created_nodes ( layer , _created_nodes ) ;
164+ layer_output = layer . Apply ( layer_input ) ;
165+ // Keep track of nodes just created above
166+ track_nodes_created_by_last_call ( layer , _created_nodes ) ;
167+ layer_input = layer_output ;
168+ outputs = layer_output ;
169+ }
170+ _init_graph_network ( inputs , outputs ) ;
171+ _graph_initialized = true ;
172+ _inferred_input_shape = input_shape ;
173+ }
174+
175+ void clear_previously_created_nodes ( ILayer layer , List < INode > created_nodes )
176+ {
177+
178+ }
179+
180+ void track_nodes_created_by_last_call ( ILayer layer , List < INode > created_nodes )
181+ {
182+ var node = layer . InboundNodes . Last ( ) ;
183+ created_nodes . Add ( node ) ;
184+ foreach ( var prev_layer in node . iterate_inbound ( ) )
185+ {
186+ created_nodes . add ( prev_layer . Item1 . OutboundNodes . Last ( ) ) ;
122187 }
123188 }
124189 }
0 commit comments