@@ -48,7 +48,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
4848 getBackwardFunction : ( ) => backward_function ) ;
4949 }
5050
51- ( BackwardFunction , Tensors ) _wrap_backward_function ( FuncGraph forward_graph , ConcreteFunction backward , Tensors flat_outputs )
51+ ( BackwardFunction , Tensors ) _wrap_backward_function ( FuncGraph forward_graph , ConcreteFunction backward , Tensors outputs )
5252 {
5353 BackwardFunction _backward_function_wrapper = ( output_grads , unneeded_gradients ) =>
5454 {
@@ -61,10 +61,11 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
6161 processed_args . add ( arg ) ;
6262 input_index += 1 ;
6363 }
64- return output_grads ; // backward.Invoke(processed_args.ToArray());
64+
65+ return backward . CallFlat ( processed_args . ToArray ( ) , outputs ) ;
6566 } ;
6667
67- return ( _backward_function_wrapper , flat_outputs ) ;
68+ return ( _backward_function_wrapper , outputs ) ;
6869 }
6970
7071 protected ( EagerDefinedFunction , FuncGraph , ConcreteFunction , List < int > , int )
@@ -82,24 +83,23 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
8283 }
8384
8485 var gradients_wrt_outputs = new List < Tensor > ( ) ;
85- var backwards_graph = new FuncGraph ( $ "{ _BACKWARD_PREFIX } { _func_graph . FuncName } _{ ops . uid ( ) } ") ;
86+ var backwards_graph = new FuncGraph ( $ "{ _BACKWARD_PREFIX } _{ ops . uid ( ) } ") ;
8687 foreach ( var output in trainable_outputs )
8788 gradients_wrt_outputs . Add ( tf . placeholder ( output . dtype , output . shape ) ) ;
8889 var gradients_wrt_inputs = gradients_util . _GradientsHelper ( trainable_outputs . ToArray ( ) ,
8990 _func_graph . Inputs ,
9091 grad_ys : gradients_wrt_outputs . ToArray ( ) ,
9192 src_graph : _func_graph ) ;
9293
93- tf . Context . restore_mode ( ) ;
94-
95- var forward_function_name = $ "{ _FORWARD_PREFIX } { _func_graph . FuncName } _{ ops . uid ( ) } ";
94+ var forward_function_name = $ "{ _FORWARD_PREFIX } _{ ops . uid ( ) } ";
9695 var backward_function_attr = new Dictionary < string , string > ( ) ;
9796 backward_function_attr [ FORWARD_FUNCTION_ATTRIBUTE_NAME ] = forward_function_name ;
97+ gradients_wrt_outputs . append ( backwards_graph . internal_captures ( ) ) ;
9898 backwards_graph . Inputs = gradients_wrt_outputs ;
9999 backwards_graph . Outputs = gradients_wrt_inputs ;
100100
101101 var backward_function = new ConcreteFunction ( backwards_graph , backward_function_attr ) ;
102-
102+
103103 var forward_function_attr = new Dictionary < string , string > ( ) ;
104104 forward_function_attr [ BACKWARD_FUNCTION_ATTRIBUTE_NAME ] = backward_function . Name ;
105105 var forward_function = new EagerDefinedFunction ( forward_function_name , _func_graph ,
0 commit comments