@@ -150,40 +150,56 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
150150
151151 void _build_graph_network_for_inferred_shape ( Shape input_shape , TF_DataType input_dtype )
152152 {
153+ if ( _inferred_input_shape == input_shape )
154+ return ;
155+
153156 ops . init_scope ( ) ;
154157 var inputs = keras . Input ( batch_input_shape : input_shape ,
155158 dtype : input_dtype ,
156159 name : $ "{ _layers [ 0 ] . Name } _input") ;
157160 Tensors layer_input = inputs ;
158161 Tensors layer_output = null ;
159162 Tensors outputs = null ;
160-
163+ List < INode > created_nodes = new List < INode > ( ) ;
161164 foreach ( var layer in _layers )
162165 {
163166 clear_previously_created_nodes ( layer , _created_nodes ) ;
164167 layer_output = layer . Apply ( layer_input ) ;
165168 // Keep track of nodes just created above
166- track_nodes_created_by_last_call ( layer , _created_nodes ) ;
169+ track_nodes_created_by_last_call ( layer , created_nodes ) ;
167170 layer_input = layer_output ;
168171 outputs = layer_output ;
169172 }
173+ _created_nodes = created_nodes ;
170174 _init_graph_network ( inputs , outputs ) ;
171175 _graph_initialized = true ;
172176 _inferred_input_shape = input_shape ;
173177 }
174178
175179 void clear_previously_created_nodes ( ILayer layer , List < INode > created_nodes )
176180 {
181+ foreach ( var node in layer . InboundNodes )
182+ {
183+ foreach ( var prev_layer in node . InboundLayers )
184+ {
185+ var outNodes = prev_layer . OutboundNodes . Where ( x => ! created_nodes . Contains ( x ) ) . ToArray ( ) ;
186+ prev_layer . OutboundNodes . Clear ( ) ;
187+ prev_layer . OutboundNodes . AddRange ( outNodes ) ;
188+ }
189+ }
177190
191+ var inNodes = layer . InboundNodes . Where ( x => ! created_nodes . Contains ( x ) ) . ToArray ( ) ;
192+ layer . InboundNodes . Clear ( ) ;
193+ layer . InboundNodes . AddRange ( inNodes ) ;
178194 }
179195
180196 void track_nodes_created_by_last_call ( ILayer layer , List < INode > created_nodes )
181197 {
182198 var node = layer . InboundNodes . Last ( ) ;
183199 created_nodes . Add ( node ) ;
184- foreach ( var prev_layer in node . iterate_inbound ( ) )
200+ foreach ( var prev_layer in node . InboundLayers )
185201 {
186- created_nodes . add ( prev_layer . Item1 . OutboundNodes . Last ( ) ) ;
202+ created_nodes . add ( prev_layer . OutboundNodes . Last ( ) ) ;
187203 }
188204 }
189205 }
0 commit comments