@@ -25,6 +25,7 @@ public abstract class TapeGradientFunctions
     protected List<int> _forwardprop_output_indices;
     protected int _num_forwardprop_outputs;
     protected ConcreteFunction _backward;
+    BackwardFunction _backward_function_wrapper;
 
     public TapeGradientFunctions(FuncGraph func_graph,
         bool need_gradients_for_jvps)
@@ -58,60 +59,66 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
     /// <returns></returns>
     (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
     {
-        var capture_mapping = new Dictionary<long, Tensor>();
-        foreach (var (i, output) in enumerate(outputs))
-            capture_mapping[forward_graph.Outputs[i].Id] = output;
-
-        var remapped_captures = new Tensors();
-        foreach (var capture in backward.CapturedInputs)
-        {
-            if (capture_mapping.ContainsKey(capture.Id))
-                remapped_captures.Add(capture_mapping[capture.Id]);
-        }
-
         var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
         var recorded_outputs = new Tensors();
-        var relevant_outputs = outputs;
         var trainable_recorded_outputs = 0;
-        var skip_positions = new List<int>();
-        foreach (var (output_index, output) in enumerate(relevant_outputs))
+        foreach (var (output_index, output) in enumerate(outputs))
         {
             if (trainable_recorded_outputs < backward_function_inputs)
                 recorded_outputs.Add(output);
             if (gradients_util.IsTrainable(output))
                 trainable_recorded_outputs += 1;
-            else
-                skip_positions.Add(output_index);
         }
 
-        BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
+        if (_backward_function_wrapper == null)
         {
-            var processed_args = new Tensors();
-            var input_index = 0;
-            foreach (var (output_index, arg) in enumerate(args))
+            var capture_mapping = new Dictionary<long, Tensor>();
+            foreach (var (i, output) in enumerate(outputs))
+                capture_mapping[forward_graph.Outputs[i].Id] = output;
+
+            var remapped_captures = new Tensors();
+            foreach (var capture in backward.CapturedInputs)
             {
-                if (skip_positions.Contains(output_index))
-                    continue;
-                if (arg == null)
-                    throw new NotImplementedException("");
-                processed_args.Add(arg);
-                input_index += 1;
-                if (input_index >= backward_function_inputs)
-                    break;
+                if (capture_mapping.ContainsKey(capture.Id))
+                    remapped_captures.Add(capture_mapping[capture.Id]);
             }
 
-            tf.Logger.Debug($"Invoke backward function: {backward.Name}");
-            var gradients = backward.CallFlat(processed_args, remapped_captures);
-
-            foreach (var unneeded_gradient_index in unneeded_gradients)
+            var skip_positions = new List<int>();
+            foreach (var (output_index, output) in enumerate(outputs))
             {
-                var index = Convert.ToInt32(unneeded_gradient_index);
-                if (gradients.Length <= index)
-                    gradients.Insert(index, null);
+                if (!gradients_util.IsTrainable(output))
+                    skip_positions.Add(output_index);
             }
 
-            return gradients;
-        };
+            _backward_function_wrapper = (args, unneeded_gradients) =>
+            {
+                var processed_args = new Tensors();
+                var input_index = 0;
+                foreach (var (output_index, arg) in enumerate(args))
+                {
+                    if (skip_positions.Contains(output_index))
+                        continue;
+                    if (arg == null)
+                        throw new NotImplementedException("");
+                    processed_args.Add(arg);
+                    input_index += 1;
+                    if (input_index >= backward_function_inputs)
+                        break;
+                }
+
+                tf.Logger.Debug($"Invoke backward function: {backward.Name}");
+                var gradients = backward.CallFlat(processed_args, remapped_captures);
+
+                foreach (var unneeded_gradient_index in unneeded_gradients)
+                {
+                    var index = Convert.ToInt32(unneeded_gradient_index);
+                    if (gradients.Length <= index)
+                        gradients.Insert(index, null);
+                }
+
+                return gradients;
+            };
+        }
 
         return (_backward_function_wrapper, recorded_outputs);
     }
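
The net effect of this change is lazy initialization: the capture remapping, the skip-position scan, and the construction of the backward delegate now run only on the first call to `_wrap_backward_function`, after which the cached `_backward_function_wrapper` field (and the state its closure captured) is reused. Below is a minimal standalone sketch of the same cache-the-closure pattern; all names here are hypothetical stand-ins, not the actual TensorFlow.NET API.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical delegate mirroring the shape of BackwardFunction.
delegate float[] Backward(float[] args, List<int> unneededGradients);

class WrapperCache
{
    Backward _wrapper;                 // cached across calls, like _backward_function_wrapper
    readonly List<int> _skipPositions; // positions of non-trainable outputs (assumed precomputed)

    public WrapperCache(List<int> skipPositions) => _skipPositions = skipPositions;

    public Backward GetWrapper()
    {
        if (_wrapper == null)          // expensive setup happens exactly once
        {
            var skip = new HashSet<int>(_skipPositions);
            _wrapper = (args, unneeded) =>
            {
                // Drop args at skipped positions, mirroring the
                // skip_positions check inside the real wrapper.
                var processed = args.Where((_, i) => !skip.Contains(i)).ToArray();
                return processed;      // stand-in for backward.CallFlat(...)
            };
        }
        return _wrapper;               // later calls reuse the same closure
    }
}

class Program
{
    static void Main()
    {
        var cache = new WrapperCache(new List<int> { 1 });
        var fn = cache.GetWrapper();
        var grads = fn(new[] { 0.5f, 1.0f, 2.0f }, new List<int>());
        Console.WriteLine(string.Join(", ", grads));                 // 0.5, 2
        Console.WriteLine(ReferenceEquals(fn, cache.GetWrapper()));  // True
    }
}
```

One design point worth noting: because `capture_mapping`, `remapped_captures`, and `skip_positions` are computed inside the `if` block and captured by the cached closure, subsequent calls to `_wrap_backward_function` skip that work entirely rather than recomputing it per call.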