Skip to content

Commit 59cbca5

Browse files
committed
ControlFlowState.PostProcessing
1 parent 3377f12 commit 59cbca5

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,35 @@ public Tensor ZerosLikeForExit(Tensor val)
290290

291291
public void PostProcessing()
292292
{
293-
throw new NotImplementedException("PostProcessing");
293+
foreach(var grad_state in _map.Values)
294+
{
295+
foreach(var b_merge in grad_state.switch_map.Values)
296+
{
297+
if(b_merge.op.inputs[0] == b_merge.op.inputs[1])
298+
{
299+
Tensor next_grad_val = null;
300+
// The value of this loop variable at iteration i+1 doesn't
301+
// depend on its value at iteration i. So use zeros as the
302+
// gradients for all iterations > 0.
303+
var dtype = b_merge.op.inputs[0].dtype;
304+
var shape = b_merge.op.inputs[0].TensorShape;
305+
if (shape.is_fully_defined())
306+
{
307+
grad_state.grad_context.Enter();
308+
// Create a zeros and use it for iterations > 0.
309+
var grad_val = constant_op.constant(0, dtype: dtype, shape: shape);
310+
next_grad_val = control_flow_ops._NextIteration(grad_val);
311+
grad_state.grad_context.Exit();
312+
}
313+
else
314+
{
315+
throw new NotImplementedException("PostProcessing shape is not fully defined.");
316+
}
317+
318+
b_merge.op._update_input(1, next_grad_val);
319+
}
320+
}
321+
}
294322
}
295323
}
296324
}

0 commit comments

Comments
 (0)