Commit 07ddb74

GradLoopState.AddBackpropAccumulatedValue
1 parent 3e07372 commit 07ddb74

File tree

1 file changed (+67, -35 lines)

src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs

Lines changed: 67 additions & 35 deletions
@@ -78,6 +78,26 @@ public class GradLoopState
         /// </summary>
         public int pending_exits_count { get; set; }

+        Operation _grad_sync;
+        public Operation grad_sync
+        {
+            get
+            {
+                if(_grad_sync == null)
+                {
+                    tf_with(ops.control_dependencies(null), delegate
+                    {
+                        _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
+                    });
+                    _grad_sync._set_control_flow_context(_grad_context);
+                    _grad_index.op._add_control_input(_grad_sync);
+                    if (_grad_context.outer_context != null)
+                        _grad_context.outer_context.AddInnerOp(_grad_sync);
+                }
+                return _grad_sync;
+            }
+        }
+
         public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         {
             // Information needed by backprop.
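The grad_sync property added above lazily creates a ControlTrigger op named "b_sync" inside the gradient while-context, makes the gradient loop counter (_grad_index) wait on it, and registers it with the outer context. Further down in this same commit, AddBackpropAccumulatedValue attaches each stack pop as a control input of grad_sync, so that with parallel_iterations > 1 every pop is ordered after pivot_for_body and before the trigger. A hedged consumption sketch, reusing only calls that appear in this diff (grad_state is a placeholder name for a GradLoopState instance):

    // Sketch only, not part of the commit: order a history-stack pop before
    // the shared b_sync trigger when loop iterations may run in parallel.
    var pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
    if (grad_state.grad_context.parallel_iterations > 1)
        grad_state.grad_sync._add_control_input(pop.op);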
@@ -155,7 +175,7 @@ public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            using (_forward_index.graph.as_default())
+            _forward_index.graph.as_default();
             {
                 var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
                 return tf_with(ops.control_dependencies(null), delegate
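Both AddForwardAccumulator and the grad_sync getter wrap op creation in tf_with(ops.control_dependencies(null), ...): clearing the ambient control dependencies keeps the stack and trigger ops from inheriting whatever dependencies happen to be active at the call site. A minimal illustration of the pattern (illustrative sketch, not code from this file):

    // Ops created inside this scope start with an empty control-dependency set.
    tf_with(ops.control_dependencies(null), delegate
    {
        var trigger = gen_control_flow_ops.control_trigger(name: "example_sync");
    });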
@@ -220,38 +240,33 @@ public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)

         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch= false)
         {
-            throw new NotImplementedException();
-            // history_ctxt = history_value.op._get_control_flow_context()
-            // # Find the cond context that controls history_value if any.
-            // cond_ctxt = None
-            // value_ctxt = value.op._get_control_flow_context()
-            // while value_ctxt and value_ctxt != history_ctxt:
-            //     if isinstance(value_ctxt, CondContext):
-            //         cond_ctxt = value_ctxt
-            //         break
-            //     value_ctxt = value_ctxt.outer_context
-            // with ops.control_dependencies(None):
-            //     self.grad_context.Enter()
-            //     if cond_ctxt:
-            //         # Guard stack pop with a switch if it is controlled by a cond.
-            //         grad_state = self
-            //         pred = None
-            //         while pred is None and grad_state:
-            //             pred = grad_state.history_map.get(cond_ctxt.pred.name)
-            //             grad_state = grad_state.outer_grad_state
-            //         if pred is None:
-            //             pred = cond_ctxt.pred
-            //         branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
-            //         history_value = _SwitchRefOrTensor(history_value, pred)[branch]
-            //     pop = gen_data_flow_ops.stack_pop_v2(history_value,
-            //                                          value.dtype.base_dtype)
-            //     pop.set_shape(value.get_shape())
-            //     self.grad_context.Exit()
-            //     parallel_iterations = self.grad_context.parallel_iterations
-            //     if parallel_iterations > 1:
-            //         # All pops are ordered after pivot_for_body and before grad_sync.
-            //         self.grad_sync._add_control_input(pop.op)
-            //     return pop
+            var history_ctxt = history_value.op._get_control_flow_context();
+            // Find the cond context that controls history_value if any.
+            CondContext cond_ctxt = null;
+            Tensor pop = null;
+            var value_ctxt = value.op._get_control_flow_context();
+            while(value_ctxt != null && value_ctxt != history_ctxt)
+            {
+                if (value_ctxt is CondContext cc)
+                    cond_ctxt = cc;
+                value_ctxt = value_ctxt.outer_context;
+            }
+            tf_with(ops.control_dependencies(null), delegate
+            {
+                grad_context.Enter();
+                if(cond_ctxt != null)
+                {
+                    throw new NotImplementedException("AddBackpropAccumulatedValue");
+                }
+                pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
+                pop.set_shape(value.TensorShape);
+                grad_context.Exit();
+            });
+            var parallel_iterations = grad_context.parallel_iterations;
+            if (parallel_iterations > 1)
+                // All pops are ordered after pivot_for_body and before grad_sync.
+                grad_sync._add_control_input(pop.op);
+            return pop;
         }

         /// <summary>
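AddForwardAccumulator pushes a tensor's value onto a per-tensor stack on every forward iteration; the AddBackpropAccumulatedValue implementation above pops those values back off while the gradient loop replays the iterations in reverse. A self-contained toy model of that pairing (plain C#, not TensorFlow.NET code):

    using System;
    using System.Collections.Generic;

    // Toy model of the forward/backward accumulator pairing.
    class LoopHistory
    {
        readonly Stack<double> _acc = new Stack<double>();

        // Forward while-loop: record the value produced at each iteration.
        public void AddForwardAccumulator(double value) => _acc.Push(value);

        // Gradient while-loop: consume the recorded values in reverse order.
        public double AddBackpropAccumulatedValue() => _acc.Pop();
    }

    class Demo
    {
        static void Main()
        {
            var history = new LoopHistory();
            for (int i = 0; i < 3; i++)               // forward pass: 3 iterations
                history.AddForwardAccumulator(i * 1.5);

            for (int i = 0; i < 3; i++)               // backward pass prints 3, 1.5, 0
                Console.WriteLine(history.AddBackpropAccumulatedValue());
        }
    }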
@@ -272,11 +287,28 @@ public Tensor GetRealValue(Tensor value)
                 var enter_op = util.GetLoopConstantEnter(cur_value);
                 if(enter_op != null)
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // Special case: cur_value comes from a constant Enter node.
+                    cur_value = enter_op.inputs[0];
+                    cur_grad_state = cur_grad_state.outer_grad_state;
+                    if(cur_grad_state == null)
+                    {
+                        // We are now outside all nested loops for this gradient(),
+                        // so `value` is a loop invariant and there is no need to
+                        // save the history of value. Just make cur_value to enter
+                        // the right control flow context.
+                        real_value = _grad_context.AddValue(cur_value);
+                        break;
+                    }
                 }
                 else if (constant_op.is_constant(cur_value))
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // We are now outside all nested loops for this gradient(),
+                    // so `value` is a loop invariant and there is no need to
+                    // save the history of value. Just make cur_value to enter
+                    // the right control flow context.
+                    real_value = constant_op.constant(
+                        tensor_util.constant_value(cur_value), dtype: cur_value.dtype);
+                    break;
                 }
                 else
                 {
0 commit comments