@@ -17,6 +17,7 @@ limitations under the License.
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Operations.ControlFlows;
 using static Tensorflow.Binding;

 namespace Tensorflow
@@ -82,6 +83,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
             var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
             var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);

+            // Add the initial gradients for the ys.
             foreach (var (y, grad_y) in zip(ys, grad_ys))
                 _SetGrad(grads, y, grad_y);

@@ -103,12 +105,25 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                 }
             }

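+            // For unused while-loop exits, backprop a zero gradient when the exit
+            // is trainable (mirrors the unused-exit handling in
+            // _UpdatePendingAndEnqueueReady below).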
+            if (loop_state != null)
+            {
+                var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
+                foreach (var y in loop_exits)
+                {
+                    if (IsTrainable(y))
+                        _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
+                    queue.Enqueue(y.op);
+                }
+            }
+
             var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
             while (queue.Count > 0)
             {
                 // generate gradient subgraph for op.
                 var op = queue.Dequeue();
                 _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                 //if (loop_state != null)
                 //    loop_state.EnterGradWhileContext(op, before: true);
@@ -147,8 +162,8 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                     }
                 }

-                // if (loop_state)
-                //     loop_state.EnterGradWhileContext(op, before: false);
+                if (loop_state != null)
+                    loop_state.EnterGradWhileContext(op, before: false);

                 if ((is_func_call || grad_fn != null) && has_out_grads)
                 {
@@ -164,7 +179,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
                                 // will use SymbolicGradient get a zero gradient. Gradient
                                 // functions should ignore the gradient for other outputs.
                                 if (loop_state != null)
-                                    ;
+                                    out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                                 else
                                     out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
                             }
@@ -275,7 +290,7 @@ private static void _maybe_colocate_with(Operation op, string gradient_uid, bool
         /// <param name="colocate_gradients_with_ops"></param>
         /// <param name="func_graphs"></param>
         /// <param name="xs"></param>
-        private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
+        private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
         {
             // Mark reachable ops from from_ops.
             var reached_ops = new List<Operation>();
@@ -308,6 +323,7 @@ private static (Operation[], Dictionary<string, int>, object) _PendingCount(List
             // 'loop_state' is null if there are no while loops.
             var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);

+            // Initialize pending count for between ops.
             var pending_count = new Dictionary<string, int>();
             foreach (var op in between_op_list)
             {
@@ -550,7 +566,7 @@ private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<T
             Operation op,
             Queue<Operation> queue,
             Dictionary<string, int> pending_count,
-            object loop_state,
+            ControlFlowState loop_state,
             Tensor[] xs)
         {
             foreach (var x in _NonEagerInputs(op, xs))
@@ -564,14 +580,49 @@ private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<T

                 if (loop_state != null && !ready)
                 {
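+                    // A loop Switch can be processed as soon as some gradient is
+                    // available: its back-edge consumer may never deliver one, so
+                    // waiting for the full pending count could stall the loop.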
-
+                    ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op);
                 }

                 if (ready)
                 {
+                    // If x is an exit without a real gradient, defer processing it.
                     if (control_flow_util.IsLoopExit(x.op))
                     {
-
+                        var grad_state = loop_state.GetGradState(x.op, before: false);
+                        grad_state.deferred_exits.append(x);
+                        grad_state.pending_exits_count -= 1;
+                        // We now have all the exits so process them.
+                        if (grad_state.pending_exits_count == 0)
+                        {
+                            var has_not_none_grad = false;
+                            foreach (var y in grad_state.deferred_exits)
+                            {
+                                if (_HasAnyNotNoneGrads(grads, y.op))
+                                {
+                                    has_not_none_grad = true;
+                                    queue.Enqueue(y.op);
+                                }
+                                else
+                                    grad_state.unused_exits.append(y);
+                            }
+                            if (has_not_none_grad)
+                            {
+                                // For an unused exit, if it has trainable outputs, backprop
+                                // a zero gradient. Otherwise, just ignore it.
+                                foreach (var y in grad_state.unused_exits)
+                                {
+                                    if (IsTrainable(y))
+                                        _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
+                                    queue.Enqueue(y.op);
+                                }
+                            }
+                            else
+                            {
+                                // All exits are "unused" so use null as the gradient.
+                                foreach (var y in grad_state.unused_exits)
+                                    queue.Enqueue(y.op);
+                            }
+                        }
                     }
                     else
                     {
@@ -581,6 +632,32 @@ private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<T
             }
         }

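+        /// <summary>
+        /// Return true if the tensor's dtype is trainable, i.e. a floating-point,
+        /// complex, resource or variant type that can receive a gradient.
+        /// </summary>
+        /// <param name="tensor"></param>
+        /// <returns></returns>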
+        private static bool IsTrainable(Tensor tensor)
+        {
+            var dtype = tensor.dtype.as_base_dtype();
+            return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,
+                dtypes.complex64, dtypes.complex128,
+                dtypes.resource, dtypes.variant }.Contains(dtype);
+        }
+
+        /// <summary>
+        /// Return true if op has real gradient.
+        /// </summary>
+        /// <param name="grads"></param>
+        /// <param name="op"></param>
+        /// <returns></returns>
+        private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
+        {
+            var out_grads = _GetGrads(grads, op);
+            foreach (var out_grad in out_grads)
+            {
+                if (out_grad.Exists(g => g != null))
+                    return true;
+            }
+            return false;
+        }
+
         private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
         {
             scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
@@ -589,6 +666,9 @@ private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_g

         private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
         {
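+            // Gradients for While/StatelessWhile ops need not line up one-to-one
+            // with the op's inputs, so skip the count check for them.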
+            if (op.type == "While" || op.type == "StatelessWhile")
+                return;
+
             if (grads.Count() != op.inputs._inputs.Count())
                 throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
                     $"inputs {op.inputs._inputs.Count()}");