
Commit e2190c9

ControlFlow MergeOutput
1 parent 5ee46e4 commit e2190c9

5 files changed: +459 additions, −309 deletions


src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 26 additions & 2 deletions
@@ -20,6 +20,7 @@ limitations under the License.
 using Tensorflow.Operations.ControlFlows;
 using static Tensorflow.ControlFlowContextDef;
 using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;
 
 namespace Tensorflow.Operations
 {
@@ -146,6 +147,14 @@ public virtual void Exit()
             graph._set_control_flow_context(last_context);
         }
 
+        public void ExitResult(Tensor[] result)
+        {
+            if (_outer_context != null)
+            {
+                throw new NotImplementedException("ExitResult");
+            }
+        }
+
         /// <summary>
         /// Add `op` to the current context.
         /// </summary>
@@ -172,6 +181,11 @@ public virtual Tensor AddValue(Tensor val)
             return null;
         }
 
+        public void AddName(string name)
+        {
+            _values.Add(name);
+        }
+
         /// <summary>
         /// Notifies a scope about an operator added to an inner scope.
         /// </summary>
@@ -246,9 +260,11 @@ protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operati
                 }
                 else
                 {
-                    foreach (Tensor x in op.control_inputs)
+                    foreach (Operation x in op.control_inputs)
                     {
-                        throw new NotImplementedException("");
+                        var ctxt = util.GetOutputContext(x);
+                        if (ctxt != null && ctxt.GetWhileContext() == while_ctxt)
+                            internal_control_inputs.append(x);
                     }
                 }
 
@@ -288,6 +304,14 @@ protected ControlFlowContext from_control_flow_context_def(ControlFlowContextDef
                 throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
         }
 
+        public virtual bool IsWhileContext()
+        {
+            throw new NotImplementedException("IsWhileContext");
+        }
+
+        public virtual bool IsCondContext()
+            => false;
+
         public object to_proto()
         {
             throw new NotImplementedException();

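The new IsWhileContext/IsCondContext virtuals let callers ask what kind of context they hold without casting. As a rough illustration only (not part of this commit; it assumes just the outer_context property used elsewhere in this change set), an outer-context walk could use them like this:

        // Illustrative helper only; assumes the ControlFlowContext members shown above.
        static ControlFlowContext NearestWhileContext(ControlFlowContext ctxt)
        {
            // Walk outward until a while-loop context (or null) is reached.
            while (ctxt != null && !ctxt.IsWhileContext())
                ctxt = ctxt.outer_context;
            return ctxt;
        }
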
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs

Lines changed: 157 additions & 109 deletions
@@ -14,13 +14,20 @@ You may obtain a copy of the License at
 limitations under the License.
 ******************************************************************************/
 
+using System;
+using System.Linq;
+using System.Collections.Generic;
+using util = Tensorflow.control_flow_util;
+using static Tensorflow.Binding;
+
 namespace Tensorflow.Operations.ControlFlows
 {
     /// <summary>
     /// Maintain the mapping from the loops to their grad states.
     /// </summary>
     public class ControlFlowState
     {
+        Dictionary<ControlFlowContext, GradLoopState> _map;
         //class ControlFlowState(object):
         //    """Maintain the mapping from the loops to their grad states."""
 
@@ -40,51 +47,67 @@ public class ControlFlowState
         //      return self._map.get(forward_ctxt)
         //    return None
 
-        //  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
-        //    """Process all the "unused" loop exits.
-
-        //    The "unused" exits of the loops are added to `unused_exits`. An exit is
-        //    unused if its pending_count is 0. If there is an exit with real gradient,
-        //    all these deferred exits will enter the backprop loop with zero gradient.
-        //    Otherwise, they will enter the backprop loop with None. As an example,
-        //    people often write:
-
-        //    ```python
-        //    v1, _ = tf.while_loop(p, b, [x1, x2])
-        //    result = gradients(v1, x1)
-        //    ```
-
-        //    The exit node for x2 is not included by the betweenness analysis. But we
-        //    need to backprop x2 if x2 is involved in computing v1.
-
-        //    Args:
-        //      pending_count: The number of backprop inputs for every op.
-        //      to_ops_set: The set of ops for ys in gradients(ys, xs)
-
-        //    Returns:
-        //      The set of unused loop exits that we know at this point we need
-        //      to backprop.
-        //    """
-        //    loop_exits = []
-        //    for grad_state in self._map.values():
-        //      for y in grad_state.forward_loop_exits:
-        //        if pending_count[y.op] == 0:
-        //          grad_state.pending_exits_count -= 1
-        //          if y.op not in to_ops_set:
-        //            grad_state.unused_exits.append(y)
-        //          if grad_state.pending_exits_count == 0:
-        //            loop_exits.extend(grad_state.unused_exits)
-        //      # Need to include Enters in backprop for higher-order gradients.
-        //      for y in grad_state.forward_context.loop_enters:
-        //        if pending_count[y.op] == 0:
-        //          pending_count[y.op] = 1
-        //    return loop_exits
-
-        //  def EnterGradWhileContext(self, op, before):
-        //    """Enter the WhileContext for gradient computation."""
-        //    grad_state = self.GetGradState(op, before)
-        //    if grad_state:
-        //      grad_state.grad_context.Enter()
+        public ControlFlowState()
+        {
+            _map = new Dictionary<ControlFlowContext, GradLoopState>();
+        }
+
+        /// <summary>
+        /// Return the grad state for this op if it's in a forward loop context.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="before"></param>
+        /// <returns></returns>
+        public GradLoopState GetGradState(Operation op, bool before)
+        {
+            ControlFlowContext forward_ctxt = null;
+            if (before && util.IsLoopExit(op))
+            {
+                forward_ctxt = op._get_control_flow_context();
+                forward_ctxt = forward_ctxt.outer_context;
+                if (forward_ctxt != null)
+                    forward_ctxt = forward_ctxt.GetWhileContext();
+            }
+            else
+                forward_ctxt = util.GetWhileContext(op);
+            if (forward_ctxt != null)
+                return _map.get(forward_ctxt);
+            return null;
+        }
+
+        public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set)
+        {
+            var loop_exits = new List<Tensor>();
+            foreach (var grad_state in _map.Values)
+            {
+                foreach (var y in grad_state.forward_loop_exits)
+                {
+                    if (!pending_count.ContainsKey(y.op.name))
+                    {
+                        grad_state.pending_exits_count -= 1;
+                        if (!to_ops_set.Contains(y.op))
+                            grad_state.unused_exits.append(y);
+                        if (grad_state.pending_exits_count == 0)
+                            loop_exits.extend(grad_state.unused_exits);
+                    }
+                }
+
+                foreach (var y in grad_state.forward_context.loop_enters)
+                {
+                    if (!pending_count.ContainsKey(y.op.name))
+                        pending_count[y.op.name] = 1;
+                }
+            }
+
+            return loop_exits.ToArray();
+        }
+
+        public void EnterGradWhileContext(Operation op, bool before)
+        {
+            var grad_state = GetGradState(op, before);
+            if (grad_state != null)
+                grad_state.grad_context.Enter();
+        }
 
         //  def ExitGradWhileContext(self, op, before):
         //    """Exit the WhileContext for gradient computation."""
@@ -118,6 +141,32 @@ public class ControlFlowState
         //      if loop_exit.op not in between_ops:
         //        between_ops.add(loop_exit.op)
         //        between_op_list.append(loop_exit.op)
+        public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
+        {
+            var forward_ctxt = op.GetWhileContext();
+            var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null;
+            if (grad_state == null)
+            {
+                GradLoopState outer_grad_state = null;
+                var outer_forward_ctxt = forward_ctxt.outer_context;
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+                if (outer_forward_ctxt != null)
+                    outer_grad_state = _map[outer_forward_ctxt];
+                grad_state = new GradLoopState(forward_ctxt, outer_grad_state);
+                _map[forward_ctxt] = grad_state;
+
+                // We need to include all exits of a loop for backprop.
+                foreach (var loop_exit in grad_state.forward_loop_exits)
+                {
+                    if (!between_ops.Contains(loop_exit.op))
+                    {
+                        between_ops.add(loop_exit.op);
+                        between_op_list.append(loop_exit.op);
+                    }
+                }
+            }
+        }
 
         //  def ZerosLikeForExit(self, val):
         //    """Create zeros_like gradient for a loop exit.
@@ -174,70 +223,69 @@ public class ControlFlowState
         //      result = array_ops.zeros_like(val, optimize=False)
         //    return result
 
-        //  def ZerosLike(self, op, index):
-        //    """Create zeros_like for the specified output of an op.
-
-        //    If op is in a while loop that is part of gradients(), this method
-        //    must be called in its grad loop context.
-
-        //    Args:
-        //      op: A tensorflow operation.
-        //      index: the index for a specific output of the op.
-
-        //    Returns:
-        //      A zero tensor of the same shape of op.outputs[index].
-        //    """
-        //    if util.IsLoopSwitch(op):
-        //      return None
-        //    if op.graph._building_function:  # pylint: disable=protected-access
-        //      # The optimization here is tricky to apply to functions
-        //      return array_ops.zeros_like(op.outputs[index])
-        //    dead_branch = util.IsSwitch(op)
-        //    forward_ctxt = _GetWhileContext(op)
-        //    grad_state = self._map.get(forward_ctxt)
-        //    if grad_state is None:
-        //      # op is not in a while loop that is part of gradients().
-        //      return ZerosLikeOutsideLoop(op, index)
-        //    op_ctxt = op._get_control_flow_context()
-        //    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
-        //    shape = val.get_shape()
-        //    if shape.is_fully_defined():
-        //      # If the shape is known statically, just create a zero tensor with
-        //      # the right shape in the grad loop context.
-        //      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
-        //      if dead_branch:
-        //        # op is a cond switch. Guard the zero tensor with a switch.
-        //        pred = grad_state.history_map.get(op_ctxt.pred.name)
-        //        branch = op_ctxt.branch
-        //        result = _SwitchRefOrTensor(result, pred)[1 - branch]
-        //    else:
-        //      # Unknown shape so keep a history of the shape at runtime.
-        //      if dead_branch:
-        //        # Need to add a special switch to guard the value.
-        //        pred = op_ctxt.pred
-        //        branch = op_ctxt.branch
-        //        op_ctxt.outer_context.Enter()
-        //        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.outer_context.Exit()
-        //        val.op._set_control_flow_context(op_ctxt)
-        //        zeros_shape.op._set_control_flow_context(op_ctxt)
-        //      else:
-        //        op_ctxt.Enter()
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.Exit()
-
-        //      # Add forward accumulator for shape.
-        //      grad_state.grad_context.Exit()
-        //      history_zeros_shape = grad_state.AddForwardAccumulator(
-        //          zeros_shape, dead_branch=dead_branch)
-        //      grad_state.grad_context.Enter()
-
-        //      # Create a zero tensor with the right shape.
-        //      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
-        //                                                     zeros_shape, dead_branch)
-        //      result = array_ops.zeros(shape, val.dtype)
-        //    return result
+        public Tensor ZerosLike(Operation op, int index)
+        {
+            if (util.IsLoopSwitch(op))
+                return null;
+            if (op.graph.building_function)
+                return array_ops.zeros_like(op.outputs[index]);
+            var dead_branch = util.IsSwitch(op);
+            var forward_ctxt = util.GetWhileContext(op);
+            var grad_state = _map.get(forward_ctxt);
+            // op is not in a while loop that is part of gradients().
+            if (grad_state == null)
+                return ZerosLikeOutsideLoop(op, index);
+            throw new NotImplementedException("ZerosLike");
+        }
+
+        public Tensor ZerosLikeOutsideLoop(Operation op, int index)
+        {
+            var val = op.outputs[index];
+            if (!util.IsSwitch(op))
+            {
+                if (val.dtype == dtypes.resource)
+                    throw new NotImplementedException("ZerosLikeOutsideLoop");
+                /*return array_ops.zeros(
+                    gen_resource_variable_ops.variable_shape(val),
+                    dtype: default_gradient.get_zeros_dtype(val));*/
+                return array_ops.zeros_like(val, optimize: false);
+            }
+            else
+                throw new NotImplementedException("ZerosLikeOutsideLoop");
+        }
+
+        /// <summary>
+        /// Create zeros_like gradient for a loop exit.
+        /// </summary>
+        /// <param name="val"></param>
+        /// <returns></returns>
+        public Tensor ZerosLikeForExit(Tensor val)
+        {
+            Tensor result = null;
+            var val_shape = val.TensorShape;
+            var forward_ctxt = val.op._get_control_flow_context();
+            var outer_forward_ctxt = forward_ctxt.outer_context;
+            if (outer_forward_ctxt != null)
+                outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+            GradLoopState outer_grad_state = null;
+            if (outer_forward_ctxt != null)
+                outer_grad_state = _map.get(outer_forward_ctxt);
+            // This is a nested loop.
+            if (outer_grad_state != null)
+            {
+                throw new NotImplementedException("ZerosLikeForExit");
+            }
+            else
+            {
+                // If the shape is known statically, just create a zero tensor
+                // with the right shape.
+                if (val_shape.is_fully_defined())
+                    result = array_ops.zeros(val_shape.dims, val.dtype);
+                else
+                    result = array_ops.zeros_like(val, optimize: false);
+            }
+            return result;
+        }
 
         //  def PostProcessing(self):
         //    """Perform postprocessing at the end of gradients().

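Taken together, ControlFlowState now covers the pieces of control-flow gradient bookkeeping that the gradient builder needs: registering while loops, entering their gradient contexts, and synthesizing zero gradients. A hypothetical call sequence, shown only for orientation (the surrounding variables such as between_op_list, between_ops, and some_op are assumptions, not part of this commit):

    // Hypothetical driver, for illustration only.
    var loop_state = new ControlFlowState();

    // While collecting ops between ys and xs, register each while loop that is crossed.
    foreach (var op in between_op_list.ToArray())
    {
        if (util.IsLoopExit(op))
            loop_state.AddWhileContext(op, between_op_list, between_ops);
    }

    // Later, when an op output has no incoming gradient, build a zero tensor
    // inside the matching gradient while-context.
    loop_state.EnterGradWhileContext(some_op, before: false);
    var zero_grad = loop_state.ZerosLike(some_op, 0);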