Skip to content

Commit 3e07372

Browse files
committed
WhileContext AddBackpropAccumulator
1 parent efed258 commit 3e07372

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,84 @@ public override void AddOp(Operation op)
446446
return (total_iterations, next_n);
447447
}
448448

449+
/// <summary>
450+
/// Add an accumulation loop for every loop invariant.
451+
/// </summary>
452+
/// <param name="op">The Enter op for a loop invariant.</param>
453+
/// <param name="grad">The partial gradient of an iteration for a loop invariant.</param>
454+
/// <returns>The gradient for a loop invariant.</returns>
455+
public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
456+
{
457+
Tensor acc = null;
458+
Exit();
459+
// Create a zeros tensor with the right shape for acc. If we don't
460+
// know the full shape statically, we will have to get the shape
461+
// dynamically from the forward inference. Getting the shape right
462+
// for the zeros is only needed for the base case when the loop exits
463+
// without running any iterations.
464+
var shape = grad.TensorShape;
465+
if (shape.is_fully_defined())
466+
{
467+
if (outer_context != null)
468+
outer_context.Enter();
469+
acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
470+
if (outer_context != null)
471+
outer_context.Exit();
472+
}
473+
else
474+
{
475+
var value = op.inputs[0];
476+
if(outer_context is WhileContext wc)
477+
{
478+
// We are in a nested while loop.
479+
var forward_ctxt = grad_state.forward_context;
480+
forward_ctxt.outer_context.Enter();
481+
var zeros_shape = array_ops.shape_internal(value, optimize: false);
482+
forward_ctxt.outer_context.Exit();
483+
var outer_grad_state = grad_state.outer_grad_state;
484+
var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
485+
outer_context.Enter();
486+
var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
487+
history_zeros_shape, zeros_shape);
488+
acc = array_ops.zeros(real_shape, grad.dtype);
489+
outer_context.Exit();
490+
}
491+
else
492+
{
493+
if (outer_context != null)
494+
outer_context.Enter();
495+
var zeros_shape = array_ops.shape_internal(value, optimize: false);
496+
acc = array_ops.zeros(zeros_shape, grad.dtype);
497+
if (outer_context != null)
498+
outer_context.Exit();
499+
}
500+
throw new NotImplementedException("AddBackpropAccumulator");
501+
}
502+
503+
Enter();
504+
AddName(acc.name);
505+
var enter_acc = _Enter(
506+
acc,
507+
_name,
508+
is_constant: false,
509+
parallel_iterations: _parallel_iterations,
510+
name: "b_acc");
511+
loop_enters.append(enter_acc);
512+
var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];
513+
514+
var switch_result = @switch(merge_acc, _pivot);
515+
var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);
516+
517+
var add_acc = math_ops.add(switch_acc_true, grad);
518+
var next_acc = _NextIteration(add_acc);
519+
merge_acc.op._update_input(1, next_acc);
520+
521+
var result_acc = exit(switch_acc_false, name: "b_acc");
522+
loop_exits.append(result_acc);
523+
ExitResult(new[] { result_acc });
524+
return result_acc;
525+
}
526+
449527
/// <summary>
450528
/// Add the backprop loop that controls the iterations.
451529
/// </summary>

0 commit comments

Comments
 (0)