@@ -123,10 +123,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
         {
             // generate gradient subgraph for op.
             var op = queue.Dequeue();
-            if (op.name == "rnn/while/Exit")
-            {

-            }
             _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
             {
                 if (loop_state != null)
@@ -136,15 +133,14 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                     loop_state.ExitGradWhileContext(op, before: true);

                 Tensor[] in_grads = null;
+                Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                 var is_partitioned_call = _IsPartitionedCall(op);
                 var is_func_call = false;
                 var has_out_grads = out_grads.Exists(x => x != null);
                 if (has_out_grads && !stop_ops.Contains(op))
                 {
                     // A grad_fn must be defined, either as a function or as None
                     // for ops that do not have gradients.
-
-                    Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                     try
                     {
                         grad_fn = ops.get_gradient_function(op);
@@ -167,61 +163,57 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                             throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                         }
                     }
+                }

-                    if (loop_state != null)
-                        loop_state.EnterGradWhileContext(op, before: false);
+                if (loop_state != null)
+                    loop_state.EnterGradWhileContext(op, before: false);

-                    if ((is_func_call || grad_fn != null) && has_out_grads)
+                if ((is_func_call || grad_fn != null) && has_out_grads)
+                {
+                    // NOTE: If _AggregatedGrads didn't compute a value for the i'th
+                    // output, it means that the cost does not depend on output[i],
+                    // therefore dC/doutput[i] is 0.
+                    foreach (var (i, out_grad) in enumerate(out_grads))
                     {
-                        // NOTE: If _AggregatedGrads didn't compute a value for the i'th
-                        // output, it means that the cost does not depend on output[i],
-                        // therefore dC/doutput[i] is 0.
-                        foreach (var (i, out_grad) in enumerate(out_grads))
-                        {
-                            if (out_grad == null &&
-                                (grad_fn == null || _IsTrainable(op.outputs[i])))
-                            {
-                                // Only trainable outputs or outputs for a function call that
-                                // will use SymbolicGradient get a zero gradient. Gradient
-                                // functions should ignore the gradient for other outputs.
-                                if (loop_state != null)
-                                    out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
-                                else
-                                    out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
-                            }
-                        }
-
-                        tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
+                        if (out_grad == null &&
+                            (grad_fn == null || _IsTrainable(op.outputs[i])))
                         {
-                            if (grad_fn != null)
-                            {
-                                in_grads = _MaybeCompile(grad_scope,
-                                    op,
-                                    out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
-                                    null,
-                                    grad_fn);
-                            }
+                            // Only trainable outputs or outputs for a function call that
+                            // will use SymbolicGradient get a zero gradient. Gradient
+                            // functions should ignore the gradient for other outputs.
+                            if (loop_state != null)
+                                out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                             else
-                            {
-                                throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
-                            }
-                            _VerifyGeneratedGradients(in_grads, op);
-                            if (gate_gradients && in_grads.Count(x => x != null) > 1)
-                            {
-                                ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
-                                in_grads = control_flow_ops.tuple(in_grads);
-                            }
-                        });
+                                out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
+                        }
                     }
-                    else
+
+                    tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                     {
-                        // If no grad_fn is defined or none of out_grads is available,
-                        // just propagate a list of None backwards.
-                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
-                    }
+                        if (grad_fn != null)
+                        {
+                            in_grads = _MaybeCompile(grad_scope,
+                                op,
+                                out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                null,
+                                grad_fn);
+                        }
+                        else
+                        {
+                            throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
+                        }
+                        _VerifyGeneratedGradients(in_grads, op);
+                        if (gate_gradients && in_grads.Count(x => x != null) > 1)
+                        {
+                            ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
+                            in_grads = control_flow_ops.tuple(in_grads);
+                        }
+                    });
                 }
                 else
                 {
+                    // If no grad_fn is defined or none of out_grads is available,
+                    // just propagate a list of None backwards.
                     in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                 }

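
The hunk above restructures the per-op step of the backward pass: the grad_fn lookup block now closes right after the try/catch, and applying the gradient (zero-filling missing output gradients, calling the registered gradient function inside an "<op>_grad" name scope, then verifying arity) moves out one nesting level. For orientation, here is a minimal, self-contained C# sketch of that shape; the registry, Backprop, and the use of double in place of Tensor are hypothetical stand-ins for illustration, not the library's actual API.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class PerOpStepSketch
    {
        // Hypothetical registry: op type -> gradient function (upstream grads -> input grads).
        static readonly Dictionary<string, Func<double[], double[]>> gradFns =
            new Dictionary<string, Func<double[], double[]>>
            {
                // For y = 2 * x, dy/dx = 2, so the incoming gradient is scaled by 2.
                ["Double"] = outGrads => new[] { 2.0 * outGrads[0] }
            };

        static double[] Backprop(string opType, double?[] outGrads, int numInputs)
        {
            bool hasOutGrads = outGrads.Any(g => g != null);
            if (!gradFns.TryGetValue(opType, out var gradFn) || !hasOutGrads)
                // No gradient function or no upstream gradients: propagate zeros ("None") backwards.
                return new double[numInputs];

            // A missing upstream gradient means the cost does not depend on that output -> zero.
            var filled = outGrads.Select(g => g ?? 0.0).ToArray();
            var inGrads = gradFn(filled);

            // Rough analogue of _VerifyGeneratedGradients: one input gradient per op input.
            if (inGrads.Length != numInputs)
                throw new InvalidOperationException("gradient function returned wrong number of gradients");
            return inGrads;
        }

        static void Main()
        {
            // dC/dx for y = 2 * x with dC/dy = 3.0 comes out as 6.0.
            Console.WriteLine(Backprop("Double", new double?[] { 3.0 }, numInputs: 1)[0]);
        }
    }
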
@@ -370,7 +362,16 @@ private static void _SetGrad(Dictionary<string, List<List<Tensor>>> grads, Tenso
                 grads[op.name] = op_grads;
             }
             var t_grads = op_grads[t.value_index];
-            t_grads.Add(grad);
+            if (t_grads.Count == 0)
+                t_grads.Add(grad);
+            else
+                op_grads[t.value_index][0] = grad;
+
+            /*if (control_flow_util.IsLoopSwitch(op) &&
+                t_grads[0] == null)
+                op_grads[t.value_index] = new List<Tensor> { grad };
+            else
+                t_grads.Add(grad);*/
         }

         private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
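
In this file, grads maps an op name to one gradient list per output index; the _SetGrad change above stops unconditionally appending and instead overwrites the first entry once one exists (the commented-out block keeps an alternative that special-cases loop-switch ops). Below is a minimal, self-contained sketch of that accumulate-or-replace rule, with double standing in for Tensor; the dictionary shape mirrors Dictionary<string, List<List<Tensor>>>, but all names and values here are illustrative only.

    using System;
    using System.Collections.Generic;

    class SetGradSketch
    {
        // grads: op name -> one gradient list per output index.
        static readonly Dictionary<string, List<List<double>>> grads =
            new Dictionary<string, List<List<double>>>();

        static void SetGrad(string opName, int valueIndex, int numOutputs, double grad)
        {
            if (!grads.TryGetValue(opName, out var opGrads))
            {
                opGrads = new List<List<double>>();
                for (int i = 0; i < numOutputs; i++)
                    opGrads.Add(new List<double>());
                grads[opName] = opGrads;
            }

            var tGrads = opGrads[valueIndex];
            if (tGrads.Count == 0)
                tGrads.Add(grad);   // first gradient for this output: append
            else
                tGrads[0] = grad;   // later gradients: overwrite, mirroring the new behaviour above
        }

        static void Main()
        {
            SetGrad("MatMul", 0, numOutputs: 1, grad: 1.0);
            SetGrad("MatMul", 0, numOutputs: 1, grad: 2.5);
            Console.WriteLine(grads["MatMul"][0][0]);   // prints 2.5: the second call replaced the first
        }
    }
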
@@ -379,15 +380,19 @@ private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
                 yield return op.inputs[i];
         }

-        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+        private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid,
+            ControlFlowState loop_state, int aggregation_method = 0)
         {
             var out_grads = _GetGrads(grads, op);

             foreach (var (i, out_grad) in enumerate(out_grads))
             {
                 if (loop_state != null)
                 {
-
+                    if (out_grads.Count > 1 &&
+                        out_grads[1].Count > 0 &&
+                        control_flow_util.IsLoopSwitch(op))
+                        continue;
                 }

                 // Aggregate multiple gradients, and convert [] to None.
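
The _AggregatedGrads hunk above tightens the loop_state parameter to ControlFlowState and, when a while-loop state is active, skips a loop-switch op whose second output already has gradients. The context comment names the step that follows: aggregate multiple gradients and convert an empty list to None. Here is a small, self-contained sketch of that sum-or-null step; using plain summation and double? in place of Tensor is an assumption made purely for illustration.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class AggregateSketch
    {
        // Per output index: sum the partial gradients, or yield null when no consumer contributed one.
        static double? Aggregate(List<double> partialGrads)
            => partialGrads.Count == 0 ? (double?)null : partialGrads.Sum();

        static void Main()
        {
            var outGrads = new List<List<double>>
            {
                new List<double> { 1.0, 2.0 },   // two consumers contributed gradients for output 0
                new List<double>()               // no consumer: the cost does not depend on output 1
            };

            for (int i = 0; i < outGrads.Count; i++)
            {
                double? g = Aggregate(outGrads[i]);
                Console.WriteLine($"output[{i}] -> {(g.HasValue ? g.Value.ToString() : "null")}");
            }
            // output[0] -> 3, output[1] -> null
        }
    }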