Commit c8a61b2

_UpdatePendingAndEnqueueReady

1 parent d2de8be

1 file changed (+87 −7)

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 87 additions & 7 deletions
@@ -17,6 +17,7 @@ limitations under the License.
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Operations.ControlFlows;
 using static Tensorflow.Binding;
 
 namespace Tensorflow
@@ -82,6 +83,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
     var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
     var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
 
+    // Add the initial gradients for the ys.
     foreach (var (y, grad_y) in zip(ys, grad_ys))
        _SetGrad(grads, y, grad_y);
 
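The added comment marks the seeding step of backprop: before any op is dequeued, every output y is paired with its incoming gradient grad_y. A hedged illustration of what this seeding means for a caller (hypothetical snippet, assuming grad_ys was left null so it defaults to ones, as in upstream TensorFlow's _DefaultGradYs):

    // Illustrative only, not part of the commit:
    var x = tf.constant(3.0f);
    var y = x * x;
    // Internally, grads[y.op] is seeded with ones_like(y) by the loop
    // above, then y.op is dequeued and the chain rule gives dy/dx = 2x = 6.
    var g = tf.gradients(new Tensor[] { y }, new Tensor[] { x });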
@@ -103,12 +105,25 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
         }
     }
 
+    if (loop_state != null)
+    {
+        var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
+        foreach (var y in loop_exits)
+        {
+            //if (IsTrainable(y))
+            throw new NotImplementedException("");
+        }
+    }
+
     var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
     while (queue.Count > 0)
     {
         // generate gradient subgraph for op.
         var op = queue.Dequeue();
+        if (op.name == "rnn/while/basic_rnn_cell/Tanh")
+        {
 
+        }
         _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
         //if (loop_state != null)
         //    loop_state.EnterGradWhileContext(op, before: true);
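Two notes on this hunk. The `ProcessUnusedLoopExits` branch is still a stub: the `IsTrainable` check is commented out and any unused loop exit currently throws `NotImplementedException`. The empty block keyed on `rnn/while/basic_rnn_cell/Tanh` reads as a temporary breakpoint anchor for debugging the RNN test case. For reference, a sketch of what the stub would presumably become, mirroring upstream TensorFlow's Python gradients_util and reusing `IsTrainable`, `_SetGrad`, and `ZerosLikeForExit` from elsewhere in this commit (an assumption, not what the code does yet):

    foreach (var y in loop_exits)
    {
        if (IsTrainable(y))
        {
            // An unused loop exit contributes a zero gradient, and its op
            // is scheduled so the loop's gradient state stays consistent.
            _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
            queue.Enqueue(y.op);
        }
    }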
@@ -147,8 +162,8 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
         }
     }
 
-    // if (loop_state)
-    //    loop_state.EnterGradWhileContext(op, before: false);
+    if (loop_state != null)
+        loop_state.EnterGradWhileContext(op, before: false);
 
     if ((is_func_call || grad_fn != null) && has_out_grads)
     {
@@ -164,7 +179,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
     // will use SymbolicGradient get a zero gradient. Gradient
     // functions should ignore the gradient for other outputs.
     if (loop_state != null)
-        ;
+        out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
     else
         out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
 }
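The fix here replaces a stray empty statement. Placeholder zeros for a missing output gradient must be created in the op's own control-flow context: inside a while loop the ControlFlowState builds them, while ops outside any loop use ZerosLikeOutsideLoop. A comment-level walkthrough (hypothetical two-output op, not part of the commit):

    // Suppose op has outputs o0, o1 and only o0 reaches the loss:
    //   out_grads[0] = { g0 };                            // real gradient
    //   out_grads[1] = { loop_state.ZerosLike(op, 1) };   // inside a loop
    //   out_grads[1] = { ZerosLikeOutsideLoop(op, 1) };   // otherwise
    // Either way the gradient function receives a full gradient list and
    // can ignore the zero entry for the unused output.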
@@ -275,7 +290,7 @@ private static void _maybe_colocate_with(Operation op, string gradient_uid, bool colocate_gradients_with_ops)
 /// <param name="colocate_gradients_with_ops"></param>
 /// <param name="func_graphs"></param>
 /// <param name="xs"></param>
-private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
+private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
 {
     // Mark reachable ops from from_ops.
     var reached_ops = new List<Operation>();
@@ -308,6 +323,7 @@ private static (Operation[], Dictionary<string, int>, object) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
     // 'loop_state' is None if there are no while loops.
     var loop_state = control_flow_ops.MaybeCreateControlFlowState(between_op_list, between_ops, colocate_gradients_with_ops);
 
+    // Initialize pending count for between ops.
     var pending_count = new Dictionary<string, int>();
     foreach (var op in between_op_list)
     {
@@ -550,7 +566,7 @@ private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<Tensor>>> grads,
     Operation op,
     Queue<Operation> queue,
     Dictionary<string, int> pending_count,
-    object loop_state,
+    ControlFlowState loop_state,
     Tensor[] xs)
 {
     foreach (var x in _NonEagerInputs(op, xs))
@@ -564,14 +580,49 @@ private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<Tensor>>> grads,
 
     if (loop_state != null && !ready)
     {
-
+        ready = pending_count[x.op.name] > 0 && control_flow_util.IsLoopSwitch(x.op);
     }
 
     if (ready)
     {
+        // if x is an exit without real gradient, defer processing them.
         if (control_flow_util.IsLoopExit(x.op))
         {
-
+            var grad_state = loop_state.GetGradState(x.op, before: false);
+            grad_state.deferred_exits.append(x);
+            grad_state.pending_exits_count -= 1;
+            // We now have all the exits so process them.
+            if (grad_state.pending_exits_count == 0)
+            {
+                var has_not_none_grad = false;
+                foreach (var y in grad_state.deferred_exits)
+                {
+                    if (_HasAnyNotNoneGrads(grads, y.op))
+                    {
+                        has_not_none_grad = true;
+                        queue.Enqueue(y.op);
+                    }
+                    else
+                        grad_state.unused_exits.append(y);
+                }
+                if (has_not_none_grad)
+                {
+                    // For an unused exit, if it has trainable outputs, backprop
+                    // a zero gradient. Otherwise, just ignore it.
+                    foreach (var y in grad_state.unused_exits)
+                    {
+                        if (IsTrainable(y))
+                            _SetGrad(grads, y, loop_state.ZerosLikeForExit(y));
+                        queue.Enqueue(y.op);
+                    }
+                }
+                else
+                {
+                    // All exits are "unused" so use None as gradient.
+                    foreach (var y in grad_state.unused_exits)
+                        queue.Enqueue(y.op);
+                }
+            }
         }
         else
         {
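The exit-deferral protocol above is easiest to follow on a concrete case, say a hypothetical while loop with two exits e1 and e2 (pending_exits_count starts at the number of exits):

    // e1 becomes ready: deferred_exits = [e1], count 2 -> 1, nothing queued.
    // e2 becomes ready: deferred_exits = [e1, e2], count 1 -> 0, process all:
    //   _HasAnyNotNoneGrads(grads, e1.op) == true  -> enqueue e1.op
    //   _HasAnyNotNoneGrads(grads, e2.op) == false -> unused_exits = [e2]
    // Because at least one exit carried a real gradient, e2 is zero-filled
    // via loop_state.ZerosLikeForExit(e2) if trainable, then enqueued.
    // Had no exit carried a gradient, both ops would be enqueued as-is.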
@@ -581,6 +632,32 @@ private static void _UpdatePendingAndEnqueueReady(Dictionary<string, List<List<Tensor>>> grads,
     }
 }
 
+private static bool IsTrainable(Tensor tensor)
+{
+    var dtype = tensor.dtype.as_base_dtype();
+    return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,
+        dtypes.complex64, dtypes.complex128,
+        dtypes.resource, dtypes.variant }.Contains(dtype);
+}
+
+/// <summary>
+/// Return true if op has real gradient.
+/// </summary>
+/// <param name="grads"></param>
+/// <param name="op"></param>
+/// <returns></returns>
+private static bool _HasAnyNotNoneGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op)
+{
+    var out_grads = _GetGrads(grads, op);
+    foreach (var out_grad in out_grads)
+    {
+        if (out_grad.Exists(g => g != null))
+            return true;
+    }
+    return false;
+}
+
+
 private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
 {
     scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
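The two new helpers are small predicates ported from the names used in upstream TensorFlow's Python implementation. An illustration of their semantics (hypothetical values, not part of the commit):

    // IsTrainable: only floating-point, complex, resource and variant
    // dtypes participate in backprop.
    //   IsTrainable(tf.constant(1.0f))        // true  (float32)
    //   IsTrainable(tf.constant(new[] { 1 })) // false (int32)
    //
    // _HasAnyNotNoneGrads(grads, op): true iff any recorded output
    // gradient of op is non-null, i.e. a real gradient has reached it.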
@@ -589,6 +666,9 @@ private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
 
 private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
 {
+    if (op.type == "While" || op.type == "StatelessWhile")
+        return;
+
     if (grads.Count() != op.inputs._inputs.Count())
         throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
             $"inputs {op.inputs._inputs.Count()}");
