@@ -55,6 +55,9 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
              * is more than one.
              **/
             var grads = new Dictionary<string, List<List<Tensor>>>();
+            Operation[] reachable_to_ops = null;
+            ControlFlowState loop_state = null;
+            Dictionary<string, int> pending_count = null;

             tf_with(ops.name_scope(name, "gradients",
                 values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
@@ -81,7 +84,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                 var to_ops = ys.Select(x => x.op).ToList();
                 var from_ops = xs.Select(x => x.op).ToList();
                 var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
-                var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
+                (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);

                 // Add the initial gradients for the ys.
                 foreach (var (y, grad_y) in zip(ys, grad_ys))
@@ -120,126 +123,135 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                 {
                     // generate gradient subgraph for op.
                     var op = queue.Dequeue();
-                    if (op.name == "rnn/while/basic_rnn_cell/Tanh ")
+                    if (op.name == "rnn/while/Exit ")
                     {

                     }
                     _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
-                    //if (loop_state != null)
-                    //loop_state.EnterGradWhileContext(op, before: true);
-                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
-
-                    Tensor[] in_grads = null;
-                    var is_partitioned_call = _IsPartitionedCall(op);
-                    var is_func_call = false;
-                    var has_out_grads = out_grads.Exists(x => x != null);
-                    if (has_out_grads && !stop_ops.Contains(op))
                     {
-                        // A grad_fn must be defined, either as a function or as None
-                        // for ops that do not have gradients.
+                        if (loop_state != null)
+                            loop_state.EnterGradWhileContext(op, before: true);
+                        var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
+                        if (loop_state != null)
+                            loop_state.ExitGradWhileContext(op, before: true);

-                        Func<Operation, Tensor[], Tensor[]> grad_fn = null;
-                        try
-                        {
-                            grad_fn = ops.get_gradient_function(op);
-                        }
-                        catch (LookupError)
+                        Tensor[] in_grads = null;
+                        var is_partitioned_call = _IsPartitionedCall(op);
+                        var is_func_call = false;
+                        var has_out_grads = out_grads.Exists(x => x != null);
+                        if (has_out_grads && !stop_ops.Contains(op))
                         {
-                            if (is_func_call)
+                            // A grad_fn must be defined, either as a function or as None
+                            // for ops that do not have gradients.
+
+                            Func<Operation, Tensor[], Tensor[]> grad_fn = null;
+                            try
                             {
-                                if (is_partitioned_call)
+                                grad_fn = ops.get_gradient_function(op);
+                            }
+                            catch (LookupError)
+                            {
+                                if (is_func_call)
                                 {
+                                    if (is_partitioned_call)
+                                    {
+
+                                    }
+                                    else
+                                    {

+                                    }
                                 }
                                 else
                                 {
-
+                                    throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                                 }
                             }
-                            else
-                            {
-                                throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
-                            }
-                        }

-                        if (loop_state != null)
-                            loop_state.EnterGradWhileContext(op, before: false);
+                            if (loop_state != null)
+                                loop_state.EnterGradWhileContext(op, before: false);

-                        if ((is_func_call || grad_fn != null) && has_out_grads)
-                        {
-                            // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                            // output, it means that the cost does not depend on output[i],
-                            // therefore dC/doutput[i] is 0.
-                            foreach (var (i, out_grad) in enumerate(out_grads))
+                            if ((is_func_call || grad_fn != null) && has_out_grads)
                             {
-                                if (out_grad == null &&
-                                    (grad_fn == null || _IsTrainable(op.outputs[i])))
+                                // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                                // output, it means that the cost does not depend on output[i],
+                                // therefore dC/doutput[i] is 0.
+                                foreach (var (i, out_grad) in enumerate(out_grads))
                                 {
-                                    // Only trainable outputs or outputs for a function call that
-                                    // will use SymbolicGradient get a zero gradient. Gradient
-                                    // functions should ignore the gradient for other outputs.
-                                    if (loop_state != null)
-                                        out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                                    else
-                                        out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                                    if (out_grad == null &&
+                                        (grad_fn == null || _IsTrainable(op.outputs[i])))
+                                    {
+                                        // Only trainable outputs or outputs for a function call that
+                                        // will use SymbolicGradient get a zero gradient. Gradient
+                                        // functions should ignore the gradient for other outputs.
+                                        if (loop_state != null)
+                                            out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
+                                        else
+                                            out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                                    }
                                 }
-                            }

-                            tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
-                            {
-                                if (grad_fn != null)
+                                tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                                 {
-                                    in_grads = _MaybeCompile(grad_scope,
-                                        op,
-                                        out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                        null,
-                                        grad_fn);
-                                }
-                                else
-                                {
-                                    throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                                }
-                                _VerifyGeneratedGradients(in_grads, op);
-                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                                {
-                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                                    in_grads = control_flow_ops.tuple(in_grads);
-                                }
-                            });
+                                    if (grad_fn != null)
+                                    {
+                                        in_grads = _MaybeCompile(grad_scope,
+                                            op,
+                                            out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                            null,
+                                            grad_fn);
+                                    }
+                                    else
+                                    {
+                                        throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                                    }
+                                    _VerifyGeneratedGradients(in_grads, op);
+                                    if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                                    {
+                                        ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                                        in_grads = control_flow_ops.tuple(in_grads);
+                                    }
+                                });
+                            }
+                            else
+                            {
+                                // If no grad_fn is defined or none of out_grads is available,
+                                // just propagate a list of None backwards.
+                                in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
+                            }
                         }
                         else
                         {
-                            // If no grad_fn is defined or none of out_grads is available,
-                            // just propagate a list of None backwards.
                             in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                         }
-                    }
-                    else
-                    {
-                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                    }

-                    var inputs = _NonEagerInputs(op, xs).ToList();
-                    foreach (var (t_in, in_grad) in zip(inputs, in_grads))
-                    {
-                        if (in_grad != null)
+                        var inputs = _NonEagerInputs(op, xs).ToList();
+                        foreach (var (t_in, in_grad) in zip(inputs, in_grads))
                         {
-                            if (!(in_grad is null) &&
-                                in_grad.Tag == null && // maybe a IndexedSlice
-                                t_in.dtype != TF_DataType.TF_RESOURCE)
+                            if (in_grad != null)
                             {
-                                in_grad.set_shape(t_in.TensorShape);
-                            }
+                                if (!(in_grad is null) &&
+                                    in_grad.Tag == null && // maybe a IndexedSlice
+                                    t_in.dtype != TF_DataType.TF_RESOURCE)
+                                {
+                                    in_grad.set_shape(t_in.TensorShape);
+                                }

-                            _SetGrad(grads, t_in, in_grad);
+                                _SetGrad(grads, t_in, in_grad);
+                            }
                         }
-                    }

+                        if (loop_state != null)
+                            loop_state.ExitGradWhileContext(op, before: false);
+                    }
+
                     // Update pending count for the inputs of op and enqueue ready ops.
                     _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
                 }
             });

+            if (loop_state != null)
+                loop_state.PostProcessing();
             return xs.Select(x => _GetGrad(grads, x)).ToArray();
         }

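For reference, the entry point that eventually reaches _GradientsHelper is the graph-mode tf.gradients API. A minimal usage sketch follows, assuming TensorFlow.NET's array overload of tf.gradients and a default graph session; the tensor values and the helper graph built here are illustrative only and not part of this commit:

    // Build y = x * x and ask the graph for dy/dx; _GradientsHelper walks the
    // graph backwards from y to x and emits the gradient ops.
    using static Tensorflow.Binding;

    var x = tf.constant(3.0f);
    var y = tf.multiply(x, x);                           // y = x^2
    var dydx = tf.gradients(new[] { y }, new[] { x });   // symbolic gradient, expected 2x
    using (var sess = tf.Session())
    {
        var g = sess.run(dydx[0]);                       // evaluates to roughly 6.0
    }

The loop_state plumbing hoisted in this commit only comes into play when ops inside a while-loop context appear on that backward walk; for a plain graph like the sketch above, loop_state is expected to stay null, so the EnterGradWhileContext/ExitGradWhileContext and PostProcessing calls are skipped.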