@@ -116,17 +116,23 @@ public static Dictionary<string, ConcreteFunction> load_function_def_library(Fun
             }
 
             Dictionary<string, ConcreteFunction> loaded_gradients = new();
-            foreach (var fdef in _sort_function_defs(library, function_deps))
+            // Debug(Rinne)
+            var temp = _sort_function_defs(library, function_deps);
+            int i = 0;
+            foreach (var fdef in temp)
             {
+                i++;
                 var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);
 
                 object structured_input_signature = null;
                 object structured_outputs = null;
                 if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
                 {
-                    var proto = saved_object_graph.ConcreteFunctions[orig_name];
-                    structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
-                    structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
+                    // TODO(Rinne): deal with structured_input_signature and structured_outputs.
+
+                    //var proto = saved_object_graph.ConcreteFunctions[orig_name];
+                    //structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
+                    //structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
                 }
 
                 graph.as_default();
@@ -234,27 +240,41 @@ private static Func<Operation, Tensor[], Tensor[]> _gen_gradient_func(ConcreteFu
 
         private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
         {
-            foreach (var op in func_graph.get_operations())
+            if (loaded_gradients is null || loaded_gradients.Count == 0)
             {
-                if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
-                {
-                    var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
-                    op.op._gradient_function = function._get_gradient_function();
-                }
-                string gradient_op_type = null;
-                try
-                {
-                    gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
-                }
-                catch (InvalidArgumentError)
+                foreach (var op in func_graph.get_operations())
                 {
-                    continue;
+                    if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
+                    {
+                        var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
+                        op.op._gradient_function = function._get_gradient_function();
+                    }
                 }
-                if (loaded_gradients.ContainsKey(gradient_op_type))
+            }
+            else
+            {
+                foreach (var op in func_graph.get_operations())
                 {
-                    var grad_fn = loaded_gradients[gradient_op_type];
-                    grad_fn.NumPositionArgs = op.op.inputs.Length;
-                    grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
+                    if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
+                    {
+                        var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
+                        op.op._gradient_function = function._get_gradient_function();
+                    }
+                    string gradient_op_type = null;
+                    try
+                    {
+                        gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
+                    }
+                    catch (InvalidArgumentError)
+                    {
+                        continue;
+                    }
+                    if (loaded_gradients.ContainsKey(gradient_op_type))
+                    {
+                        var grad_fn = loaded_gradients[gradient_op_type];
+                        grad_fn.NumPositionArgs = op.op.inputs.Length;
+                        grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
+                    }
                 }
             }
         }