@@ -49,24 +49,61 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
4949 getBackwardFunction : ( ) => backward_function ) ;
5050 }
5151
/// <summary>
/// Create a backward function given `outputs` from the forward function.
/// </summary>
/// <param name="forward_graph">
/// Forward graph whose <c>Outputs</c> are paired by position with <paramref name="outputs"/>.
/// </param>
/// <param name="backward">
/// Backward function that the returned wrapper invokes through <c>CallFlat</c>.
/// </param>
/// <param name="outputs">Outputs produced by the forward function.</param>
/// <returns>
/// The wrapper that filters the incoming gradients and calls <paramref name="backward"/>,
/// together with the subset of <paramref name="outputs"/> that is recorded for it.
/// </returns>
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
    // Map the id of each forward-graph output to the concrete output at the same position.
    var capture_mapping = new Dictionary<long, Tensor>();
    foreach (var (i, output) in enumerate(outputs))
        capture_mapping[forward_graph.Outputs[i].Id] = output;

    // Replace captures that are forward outputs with the concrete output. Any other
    // capture is passed through unchanged (the Python reference uses
    // `capture_mapping.get(id, capture)`); skipping it would shrink the captured-input list.
    var remapped_captures = new Tensors();
    foreach (var capture in backward.CapturedInputs)
    {
        remapped_captures.Add(
            capture_mapping.TryGetValue(capture.Id, out var mapped) ? mapped : capture);
    }

    // Number of backward inputs that are not captured inputs.
    var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
    var recorded_outputs = new Tensors();
    var relevant_outputs = outputs;
    var trainable_recorded_outputs = 0;
    // Positions of non-trainable outputs; the wrapper skips the arguments at these positions.
    var skip_positions = new HashSet<int>();
    foreach (var (output_index, output) in enumerate(relevant_outputs))
    {
        // Record outputs until enough trainable ones have been seen.
        if (trainable_recorded_outputs < backward_function_inputs)
            recorded_outputs.Add(output);
        if (gradients_util.IsTrainable(output))
            trainable_recorded_outputs += 1;
        else
            skip_positions.Add(output_index);
    }

    BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
    {
        var processed_args = new Tensors();
        var input_index = 0;
        foreach (var (output_index, arg) in enumerate(args))
        {
            if (skip_positions.Contains(output_index))
                continue;
            if (arg is null)
                throw new NotImplementedException(
                    $"Missing gradient for output {output_index} of backward function '{backward.Name}' is not supported.");
            processed_args.Add(arg);
            input_index += 1;
            // Stop once every non-captured backward input has been supplied.
            if (input_index >= backward_function_inputs)
                break;
        }
        tf.Logger.Debug($"Invoke backward function: {backward.Name}");
        return backward.CallFlat(processed_args, remapped_captures);
    };

    return (_backward_function_wrapper, recorded_outputs);
}
71108
72109 protected ( EagerDefinedFunction , FuncGraph , ConcreteFunction , List < int > , int )
@@ -103,7 +140,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
103140 }
104141 backwards_graph . Exit ( ) ;
105142
106- var forward_function_name = $ "{ _FORWARD_PREFIX } _{ ops . uid ( ) } ";
143+ var forward_function_name = $ "{ _FORWARD_PREFIX } _{ _func_graph . FuncName } _ { ops . uid ( ) } ";
107144 var backward_function_attr = new Dictionary < string , string > ( ) ;
108145 backward_function_attr [ FORWARD_FUNCTION_ATTRIBUTE_NAME ] = forward_function_name ;
109146 gradients_wrt_outputs . append ( backwards_graph . internal_captures ) ;
0 commit comments