@@ -53,10 +53,76 @@ public static void _GradientsHelper(object ys,
5353 using ( var namescope = new ops . name_scope < Tensor > ( name , "gradients" , values : all ) )
5454 {
5555 grad_scope = namescope ;
56+ // Get a uid for this call to gradients that can be used to help
57+ // cluster ops for compilation.
58+ var gradient_uid = ops . get_default_graph ( ) . unique_name ( "uid" ) ;
5659
60+ var to_ops = ys1 . Select ( x => x . op ) . ToList ( ) ;
61+ var from_ops = xs1 . Select ( x => x . op ) . ToList ( ) ;
62+ var stop_gradient_ops = stop_gradients1 . Select ( x => x . op ) . ToList ( ) ;
63+ _PendingCount ( to_ops , from_ops , colocate_gradients_with_ops , new List < object > ( ) , xs1 ) ;
5764 }
5865 }
5966
67+ /// <summary>
68+ ///
69+ /// </summary>
70+ /// <param name="grad_ys"></param>
71+ /// <param name="ys"></param>
72+ /// <param name="colocate_gradients_with_ops"></param>
73+ /// <param name="gradient_uid"></param>
74+ private void _DefaultGradYs ( List < Tensor > grad_ys , List < Tensor > ys , bool colocate_gradients_with_ops , string gradient_uid = "__unsupported__" )
75+ {
76+
77+ }
78+
79+ /// <summary>
80+ /// Initialize the pending count for ops between two lists of Operations.
81+ /// 'pending_count[op]' indicates the number of backprop inputs
82+ /// to this operation.
83+ /// </summary>
84+ /// <param name="to_ops"></param>
85+ /// <param name="from_ops"></param>
86+ /// <param name="colocate_gradients_with_ops"></param>
87+ /// <param name="func_graphs"></param>
88+ /// <param name="xs"></param>
89+ private static void _PendingCount ( List < Operation > to_ops , List < Operation > from_ops , bool colocate_gradients_with_ops , List < object > func_graphs , List < Tensor > xs )
90+ {
91+ List < Operation > reached_ops = new List < Operation > ( ) ;
92+ _MarkReachedOps ( from_ops , reached_ops , func_graphs ) ;
93+ }
94+
95+ /// <summary>
96+ /// Mark all ops reached from "from_ops"
97+ /// </summary>
98+ /// <param name="from_ops"></param>
99+ /// <param name="reached_ops"></param>
100+ /// <param name="func_graphs"></param>
101+ private static void _MarkReachedOps ( List < Operation > from_ops , List < Operation > reached_ops , List < object > func_graphs )
102+ {
103+ foreach ( var op in from_ops )
104+ {
105+ reached_ops . Add ( op ) ;
106+ foreach ( var output in op . outputs )
107+ {
108+ reached_ops . AddRange ( _Consumers ( output , func_graphs ) ) ;
109+ }
110+ }
111+
112+ reached_ops . Reverse ( ) ;
113+ }
114+
115+ /// <summary>
116+ /// Returns the consumers of t, crossing closure boundaries where necessary.
117+ /// </summary>
118+ /// <param name="t"></param>
119+ /// <param name="func_graphs"></param>
120+ private static List < Operation > _Consumers ( Tensor t , List < object > func_graphs )
121+ {
122+ var consumers = t . consumers ( ) ;
123+ return consumers ;
124+ }
125+
60126 private static List < Tensor > _AsList ( object ys )
61127 {
62128 List < Tensor > ret = null ;
0 commit comments