Skip to content

Commit 8243807

Browse files
committed
IsLoopConstantEnter
1 parent a65d881 commit 8243807

File tree

13 files changed

+272
-26
lines changed

13 files changed

+272
-26
lines changed

src/TensorFlowNET.Core/Util/IFlatten.cs renamed to src/TensorFlowNET.Core/Interfaces/IFlatten.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44

5-
namespace Tensorflow.Operations
5+
namespace Tensorflow
66
{
77
public interface ICanBeFlattened
88
{
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public interface IPackable
8+
{
9+
void Pack(object[] sequences);
10+
}
11+
}

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ public virtual void Exit()
170170
/// <summary>
171171
/// Add `op` to the current context.
172172
/// </summary>
173-
public void AddOp(Operation op)
173+
public virtual void AddOp(Operation op)
174174
{
175175
_AddOpInternal(op);
176176
}
@@ -210,11 +210,6 @@ public virtual void AddInnerOp(Operation op)
210210
/// </summary>
211211
protected virtual void _AddOpInternal(Operation op)
212212
{
213-
if (op.name == "rnn/while/Less")
214-
{
215-
216-
}
217-
218213
if(op == null)
219214
{
220215
throw new NotImplementedException("");
@@ -255,9 +250,34 @@ protected virtual bool _IsInOuterContext(Operation op)
255250
throw new NotImplementedException("_IsInOuterContext");
256251
}
257252

258-
protected virtual void _RemoveExternalControlEdges(Operation op)
253+
/// <summary>
254+
/// Remove any external control dependency on this op.
255+
/// </summary>
256+
/// <param name="op"></param>
257+
protected virtual (Operation[], Operation[]) _RemoveExternalControlEdges(Operation op)
259258
{
260-
var internal_control_inputs = op.control_inputs;
259+
var while_ctxt = GetWhileContext();
260+
261+
var internal_control_inputs = new List<Operation>();
262+
// A control input of `op` is internal if it is in the same while
263+
// loop context as the enclosing while loop context of self.
264+
if (while_ctxt == null)
265+
{
266+
internal_control_inputs = op.control_inputs.ToList();
267+
}
268+
else
269+
{
270+
foreach(Tensor x in op.control_inputs)
271+
{
272+
throw new NotImplementedException("");
273+
}
274+
}
275+
276+
var external_control_inputs = new List<Operation>();
277+
if (len(internal_control_inputs) != len(op.control_inputs))
278+
throw new NotImplementedException("");
279+
280+
return (internal_control_inputs.ToArray(), external_control_inputs.ToArray());
261281
}
262282

263283
/// <summary>

src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow.Operations
67
{
7-
internal class LoopVar<TItem> : ICanBeFlattened
8+
internal class LoopVar<TItem> : ICanBeFlattened, IPackable
89
{
9-
public Tensor Counter { get; }
10-
public TItem Item { get; }
10+
public Tensor Counter { get; set; }
11+
public TItem Item { get; set; }
1112

1213
public LoopVar(Tensor counter, TItem item)
1314
{
@@ -25,6 +26,13 @@ public object[] Flatten()
2526
return elements.ToArray();
2627
}
2728

29+
public void Pack(object[] sequences)
30+
{
31+
Counter = sequences[0] as Tensor;
32+
if (typeof(TItem).GetInterface(typeof(IPackable).Name) != null)
33+
(Item as IPackable).Pack(sequences.Skip(1).ToArray());
34+
}
35+
2836
public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar)
2937
{
3038
return (loopVar.Counter, loopVar.Item);

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 138 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,13 @@ private TensorShape _get_shape_invariant(Tensor var, int[] shape = null)
240240

241241
// Build the graph for body.
242242
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
243+
_pivot_for_body = vars_for_body[0];
243244
// Convert TensorArray flow variables inside the context back into
244245
// their associated TensorArrays for calling the body.
245-
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
246-
var body_result = body(original_loop_vars);
246+
var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
247+
var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays);
248+
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
249+
var body_result = body(packed_vars_for_body);
247250
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
248251

249252
// Store body_result to keep track of TensorArrays returned by body
@@ -267,17 +270,27 @@ private TensorShape _get_shape_invariant(Tensor var, int[] shape = null)
267270
private void _FixControlInputsAndContext(Tensor[] enters)
268271
{
269272
var graph = ops.get_default_graph();
270-
foreach(var e in enters)
273+
foreach(var x in enters)
271274
{
272-
var inp_op = e.op.inputs[0].op;
275+
var inp_op = x.op.inputs[0].op;
273276
var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op });
277+
var outer_control_inputs = new List<Operation>();
278+
foreach(Operation op in control_inputs)
279+
{
280+
// We need to keep control inputs that are in any ancestor
281+
// ControlFlowContext, and within outer WhileContext.
282+
var keep_as_control_input = true;
283+
var op_ctxt = control_flow_util.GetOutputContext(op);
284+
var outer_ctxt = outer_context;
285+
throw new NotImplementedException("");
286+
}
274287
// op for op in control_inputs if self._IsInOuterContext(op)
275-
var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
288+
/*var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
276289
.Select(x => x.op)
277-
.ToArray();
278-
e.op._set_control_flow_context(this);
279-
e.op._add_control_inputs(outer_control_inputs);
280-
graph._record_op_seen_by_control_dependencies(e.op);
290+
.ToArray();*/
291+
x.op._set_control_flow_context(this);
292+
x.op._add_control_inputs(outer_control_inputs.ToArray());
293+
graph._record_op_seen_by_control_dependencies(x.op);
281294
}
282295
}
283296

@@ -288,13 +301,127 @@ private void _InitializeValues(Tensor[] values)
288301
_values.Add(x.name);
289302
}
290303

304+
protected override void _AddOpInternal(Operation op)
305+
{
306+
Operation[] external_inputs = new Operation[0];
307+
if (op == null)
308+
{
309+
throw new NotImplementedException("");
310+
}
311+
else
312+
{
313+
foreach (var index in range(len(op.inputs)))
314+
{
315+
var x = op.inputs[index];
316+
var real_x = AddValue(x);
317+
if (real_x != x)
318+
op._update_input(index, real_x);
319+
}
320+
321+
// Remove any external control dependency on this op.
322+
(_, external_inputs) = _RemoveExternalControlEdges(op);
323+
// Add a control dependency to prevent loop invariants from
324+
// enabling ops that should not be executed.
325+
_MaybeAddControlDependency(op);
326+
foreach (Tensor x in op.outputs)
327+
_values.Add(x.name);
328+
}
329+
330+
if (external_inputs.Length > 0)
331+
{
332+
throw new NotImplementedException("external_inputs.Length > 0");
333+
}
334+
335+
if (_outer_context != null || !IsLoopExit(op))
336+
foreach (Tensor x in op.outputs)
337+
op.graph.prevent_feeding(x);
338+
339+
if (_outer_context != null)
340+
_outer_context.AddInnerOp(op);
341+
}
342+
343+
protected void _MaybeAddControlDependency(Operation op)
344+
{
345+
// Determines if `op` needs a control dependency.
346+
Func<Operation, bool> _IsOpFree = (op1) =>
347+
{
348+
if (op1.control_inputs.Length > 0)
349+
return false;
350+
351+
if (op1.type == "SymbolicGradient")
352+
return true;
353+
354+
foreach (Tensor x in op1.inputs)
355+
if (!control_flow_util.IsLoopConstantEnter(x.op))
356+
return false;
357+
358+
return true;
359+
};
360+
361+
if (_IsOpFree(op))
362+
op._add_control_input(GetControlPivot().op);
363+
}
364+
365+
private Tensor GetControlPivot()
366+
{
367+
if (_pivot_for_body != null)
368+
return _pivot_for_body;
369+
return _pivot_for_pred;
370+
}
371+
372+
public override void AddOp(Operation op)
373+
{
374+
_AddOpInternal(op);
375+
}
376+
291377
public override Tensor AddValue(Tensor val)
292378
{
293379
var result = val;
294-
var new_value = _values.Contains(val.name);
380+
var new_value = !_values.Contains(val.name);
295381
new_value &= val.op._get_control_flow_context() != this;
296382
if (new_value)
297-
throw new NotImplementedException("");
383+
{
384+
_values.Add(val.name);
385+
386+
// If we are in a grad context and val is from its forward context,
387+
// use GetRealValue(), which adds the logic to save the history of
388+
// val in forward.
389+
var grad_ctxt = ops.get_default_graph()._get_control_flow_context();
390+
if(grad_ctxt != null)
391+
{
392+
grad_ctxt = grad_ctxt.GetWhileContext();
393+
if (grad_ctxt.grad_state != null)
394+
{
395+
throw new NotImplementedException("");
396+
}
397+
}
398+
399+
if (_outer_context != null)
400+
{
401+
result = _outer_context.AddValue(val);
402+
}
403+
404+
// Create an Enter to make `result` known to this loop context.
405+
Tensor enter = null;
406+
tf_with(ops.control_dependencies(new ITensorOrOperation[0]), delegate
407+
{
408+
enter = _Enter(
409+
result,
410+
_name,
411+
is_constant: true,
412+
parallel_iterations: _parallel_iterations);
413+
enter.graph.prevent_feeding(enter);
414+
if (_outer_context != null)
415+
_outer_context.AddInnerOp(enter.op);
416+
});
417+
418+
// Fix the control inputs and control flow context of these enter ops.
419+
_FixControlInputsAndContext(new[] { enter });
420+
// Add `enter` in this context.
421+
_values.Add(enter.name);
422+
_external_values[val.name] = enter;
423+
result = enter;
424+
}
298425
else
299426
{
300427
var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null;

src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Tensorflow.Operations
66
{
7-
internal class BodyItemInRnnWhileLoop : ICanBeFlattened
7+
internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable
88
{
99
/// <summary>
1010
/// int32 scalar Tensor.
@@ -36,5 +36,12 @@ public object[] Flatten()
3636
elements.Add(state);
3737
return elements.ToArray();
3838
}
39+
40+
public void Pack(object[] sequences)
41+
{
42+
time = sequences[0] as Tensor;
43+
output_ta_t = new[] { sequences[1] as TensorArray };
44+
state = sequences[2] as Tensor;
45+
}
3946
}
4047
}

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,12 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
192192
// Take a time step of the dynamic RNN.
193193
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
194194
{
195-
throw new NotImplementedException("");
195+
if (in_graph_mode)
196+
{
197+
input_ta.Select(ta => ta.read(time)).ToArray();
198+
}
199+
200+
return item;
196201
};
197202

198203
control_flow_ops.while_loop(

src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,5 +159,20 @@ public void _maybe_colocate_with(Tensor value)
159159
{
160160
_colocate_with.Add(value);
161161
}
162+
163+
public Tensor read(Tensor index, string name = null)
164+
{
165+
var value = gen_data_flow_ops.tensor_array_read_v3(
166+
handle: _handle,
167+
index: index,
168+
flow_in: _flow,
169+
dtype: _dtype,
170+
name: name);
171+
172+
if (_element_shape != null)
173+
value.set_shape(_element_shape[0].dims);
174+
175+
return value;
176+
}
162177
}
163178
}

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,8 @@ public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TIt
648648
body_buildloop = (item) =>
649649
{
650650
var (i, lv) = (item.Counter, item.Item);
651-
return new LoopVar<TItem>(i + 1, orig_body(lv));
651+
var ob = orig_body(lv);
652+
return new LoopVar<TItem>(i + 1, ob);
652653
};
653654
}
654655
try_to_pack = false;

src/TensorFlowNET.Core/Operations/control_flow_util.py.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ public class control_flow_util
3030
public static bool IsLoopExit(Operation op)
3131
{
3232
return op.type == "Exit" || op.type == "RefExit";
33+
}
34+
35+
/// <summary>
36+
/// Returns true if `op` is an Enter.
37+
/// </summary>
38+
/// <param name="op"></param>
39+
/// <returns></returns>
40+
public static bool IsLoopEnter(Operation op)
41+
{
42+
return op.type == "Enter" || op.type == "RefEnter";
43+
}
44+
45+
/// <summary>
46+
/// Return true iff op is a loop invariant.
47+
/// </summary>
48+
/// <param name="op"></param>
49+
/// <returns></returns>
50+
public static bool IsLoopConstantEnter(Operation op)
51+
{
52+
return IsLoopEnter(op) && op.get_attr<bool>("is_constant");
3353
}
3454

3555
/// <summary>

0 commit comments

Comments
 (0)