Skip to content

Commit 3425aa4

Browse files
committed
LoopVar<T>
1 parent 61bc941 commit 3425aa4

File tree

3 files changed

+160
-83
lines changed

3 files changed

+160
-83
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Linq;
2020
using Tensorflow.Operations.ControlFlows;
2121
using static Tensorflow.ControlFlowContextDef;
22+
using static Tensorflow.Binding;
2223

2324
namespace Tensorflow.Operations
2425
{
@@ -72,6 +73,7 @@ public abstract class ControlFlowContext : IObjectLife
7273
public ControlFlowContext()
7374
{
7475
_context_stack = new Stack<ControlFlowContext>();
76+
_external_values = new Dictionary<string, ITensorOrOperation>();
7577
}
7678

7779
public string name { get => _name; }
@@ -180,6 +182,11 @@ public void AddOp(Operation op)
180182

181183
public virtual bool back_prop => throw new NotImplementedException("abstract method");
182184

185+
/// <summary>
186+
/// Add `val` to the current context and its outer context recursively.
187+
/// </summary>
188+
/// <param name="val"></param>
189+
/// <returns></returns>
183190
public virtual Tensor AddValue(Tensor val)
184191
{
185192
// to be overridden
@@ -203,7 +210,25 @@ public virtual void AddInnerOp(Operation op)
203210
/// </summary>
204211
protected virtual void _AddOpInternal(Operation op)
205212
{
206-
213+
if (op.name == "rnn/while/Less")
214+
{
215+
216+
}
217+
218+
if(op == null)
219+
{
220+
throw new NotImplementedException("");
221+
}
222+
else
223+
{
224+
foreach(var index in range(len(op.inputs)))
225+
{
226+
var x = op.inputs[index];
227+
var real_x = AddValue(x);
228+
if (real_x != x)
229+
op._update_input(index, real_x);
230+
}
231+
}
207232
}
208233

209234
protected bool OpInContext(Operation op)

src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,10 @@ public object[] Flatten()
2424
elements.Add(Item);
2525
return elements.ToArray();
2626
}
27+
28+
public static implicit operator (Tensor, TItem)(LoopVar<TItem> loopVar)
29+
{
30+
return (loopVar.Counter, loopVar.Item);
31+
}
2732
}
2833
}

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 129 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ private void _init_from_args(Tensor maximum_iterations,
7171
string name)
7272
{
7373
_name = ops.get_default_graph().unique_name(name);
74+
_maximum_iterations = maximum_iterations;
75+
_parallel_iterations = parallel_iterations;
7476
_back_prop = back_prop;
7577
_swap_memory = swap_memory;
7678
_loop_exits = new List<Tensor>();
@@ -107,18 +109,27 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
107109
/// <summary>
108110
/// Add the loop termination condition and body to the graph.
109111
/// </summary>
110-
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
111-
Func<Tensor, TItem, LoopVar<TItem>> body,
112+
internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
113+
Func<LoopVar<TItem>, LoopVar<TItem>> body,
112114
LoopVar<TItem> loop_vars,
113-
TensorShape shape_invariants,
115+
TensorShape[] shape_invariants,
114116
bool return_same_structure)
115117
{
116118
// Keep original_loop_vars to identify which are TensorArrays
117119
var original_loop_vars = loop_vars;
118120
// Convert TensorArrays to their flow variables
121+
var loop_vars_tensors = nest.flatten2(loop_vars)
122+
.Select(x => _convert_tensorarray_to_flow(x))
123+
.ToArray();
124+
125+
if (shape_invariants == null)
126+
shape_invariants = loop_vars_tensors
127+
.Select(x => _get_shape_invariant(x as Tensor))
128+
.ToArray();
129+
119130
Enter();
120131
var(original_body_result, exit_vars) = _BuildLoop(
121-
pred, body, original_loop_vars, loop_vars, shape_invariants);
132+
pred, body, original_loop_vars, loop_vars_tensors, shape_invariants);
122133
Exit();
123134

124135
var flat_result = original_body_result;
@@ -131,7 +142,7 @@ internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
131142
return packed_exit_vars as Tensor[];
132143
}
133144

134-
private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
145+
private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array)
135146
{
136147
if (tensor_or_tensor_array is TensorArray tensor_array)
137148
return tensor_array.flow;
@@ -141,97 +152,116 @@ private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
141152
throw new NotImplementedException("_convert_tensorarray_to_flow");
142153
}
143154

144-
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
145-
Func<Tensor, TItem, LoopVar<TItem>> body,
146-
LoopVar<TItem> original_loop_vars,
147-
LoopVar<TItem> loop_vars,
148-
TensorShape shape_invariants)
155+
private TensorShape _get_shape_invariant(Tensor var, int[] shape = null)
149156
{
150-
var flat_loop_vars = original_loop_vars;
157+
return var.TensorShape;
158+
}
151159

152-
// Convert TensorArrays to their flow variables
153-
var loop_vars_tensor = nest.map_structure(
154-
_convert_tensorarray_to_flow,
155-
nest.flatten2(loop_vars));
160+
/// <summary>
161+
/// Add the loop termination condition and body to the graph.
162+
/// </summary>
163+
/// <typeparam name="TItem"></typeparam>
164+
/// <param name="pred"></param>
165+
/// <param name="body"></param>
166+
/// <param name="original_loop_vars"></param>
167+
/// <param name="loop_vars"></param>
168+
/// <param name="shape_invariants"></param>
169+
/// <returns></returns>
170+
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
171+
Func<LoopVar<TItem>, LoopVar<TItem>> body,
172+
LoopVar<TItem> original_loop_vars,
173+
Tensor[] loop_vars,
174+
TensorShape[] shape_invariants)
175+
{
176+
var flat_loop_vars = nest.flatten2(original_loop_vars)
177+
.Select(x => (ITensorOrTensorArray)x)
178+
.ToArray();
156179

157180
// Let the context know the loop variables so the loop variables
158181
// would be added in the outer contexts properly.
159-
if (loop_vars is Tensor[] real_vars)
182+
_InitializeValues(loop_vars);
183+
var real_vars = loop_vars;
184+
Tensor[] enter_vars = null;
185+
tf_with(ops.control_dependencies(null), delegate
160186
{
161-
_InitializeValues(real_vars);
162-
Tensor[] enter_vars = null;
163-
tf_with(ops.control_dependencies(null), delegate
164-
{
165-
enter_vars = real_vars.Select(x => _Enter(x,
166-
_name,
167-
is_constant: false,
168-
parallel_iterations: _parallel_iterations,
169-
use_input_shape: shape_invariants == null))
170-
.ToArray();
171-
172-
foreach (var x in enter_vars)
173-
{
174-
x.graph.prevent_feeding(x);
175-
if (_outer_context != null)
176-
_outer_context.AddInnerOp(x.op);
177-
}
178-
});
179-
180-
// Finds the closest enclosing non-None control pivot.
181-
var outer_context = _outer_context;
182-
while (outer_context != null)
187+
enter_vars = real_vars.Select(x => _Enter(x,
188+
_name,
189+
is_constant: false,
190+
parallel_iterations: _parallel_iterations,
191+
use_input_shape: shape_invariants == null))
192+
.ToArray();
193+
194+
foreach (var x in enter_vars)
183195
{
184-
196+
x.graph.prevent_feeding(x);
197+
if (_outer_context != null)
198+
_outer_context.AddInnerOp(x.op);
185199
}
200+
});
186201

187-
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);
188-
189-
// Fix the control inputs and control flow context of these enter ops.
190-
_FixControlInputsAndContext(enter_vars);
191-
_InitializeValues(enter_vars);
192-
_loop_enters = enter_vars.ToList();
193-
194-
var merge_vars = enter_vars
195-
.Select(x => merge(new[] { x, x }))
196-
.ToArray();
202+
// Finds the closest enclosing non-None control pivot.
203+
var outer_context = _outer_context;
204+
object control_pivot = null;
205+
while (outer_context != null && control_pivot == null)
206+
{
197207

198-
_pivot_for_pred = merge_vars[0];
208+
}
199209

200-
// Build the graph for pred.
201-
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
202-
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
203-
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem)));
204-
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
205-
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
206-
.ToArray();
210+
if (control_pivot != null)
211+
{
207212

208-
// Build the graph for body.
209-
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
210-
// Convert TensorArray flow variables inside the context back into
211-
// their associated TensorArrays for calling the body.
212-
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
213-
/*var body_result = body(packed_vars_for_body[0]);
214-
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
215-
216-
// Store body_result to keep track of TensorArrays returned by body
217-
var original_body_result = new[] { body_result };
218-
// Convert TensorArrays returned by body into their flow variables
219-
var result = new[] { body_result };
220-
221-
var next_vars = new List<Tensor>();
222-
foreach (var (m, v) in zip(merge_vars, result))
223-
next_vars.Add(_AddNextAndBackEdge(m, v));
224-
225-
// Add the exit ops.
226-
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
227-
_loop_exits = exit_vars;
228-
229-
// Exit the loop.
230-
// ExitResult(exit_vars);
231-
return (original_body_result, exit_vars.ToArray());*/
232213
}
233214

234-
throw new NotImplementedException("");
215+
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);
216+
217+
// Fix the control inputs and control flow context of these enter ops.
218+
_FixControlInputsAndContext(enter_vars);
219+
_InitializeValues(enter_vars);
220+
_loop_enters = enter_vars.ToList();
221+
222+
var merge_vars = enter_vars
223+
.Select(x => merge(new[] { x, x }))
224+
.ToArray();
225+
226+
_pivot_for_pred = merge_vars[0];
227+
228+
// Build the graph for pred.
229+
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
230+
//var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true);
231+
var packed_vars = new LoopVar<TItem>((Tensor)merge_vars_with_tensor_arrays[0],
232+
(TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1],
233+
new[] { (TensorArray)merge_vars_with_tensor_arrays[2] },
234+
(Tensor)merge_vars_with_tensor_arrays[3]));
235+
var pp = pred(packed_vars);
236+
var c = ops.convert_to_tensor(pp);
237+
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
238+
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
239+
.ToArray();
240+
241+
// Build the graph for body.
242+
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
243+
// Convert TensorArray flow variables inside the context back into
244+
// their associated TensorArrays for calling the body.
245+
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
246+
var body_result = body(original_loop_vars);
247+
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
248+
249+
// Store body_result to keep track of TensorArrays returned by body
250+
var original_body_result = new[] { body_result };
251+
// Convert TensorArrays returned by body into their flow variables
252+
var result = new[] { body_result };
253+
254+
var next_vars = new List<Tensor>();
255+
//foreach (var (m, v) in zip(merge_vars, result))
256+
//next_vars.Add(_AddNextAndBackEdge(m, v));
257+
258+
// Add the exit ops.
259+
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
260+
_loop_exits = exit_vars;
261+
262+
// Exit the loop.
263+
// ExitResult(exit_vars);
264+
return (null, exit_vars.ToArray());
235265
}
236266

237267
private void _FixControlInputsAndContext(Tensor[] enters)
@@ -258,6 +288,23 @@ private void _InitializeValues(Tensor[] values)
258288
_values.Add(x.name);
259289
}
260290

291+
public override Tensor AddValue(Tensor val)
292+
{
293+
var result = val;
294+
var new_value = _values.Contains(val.name);
295+
new_value &= val.op._get_control_flow_context() != this;
296+
if (new_value)
297+
throw new NotImplementedException("");
298+
else
299+
{
300+
var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null;
301+
if (actual_val != null)
302+
result = actual_val as Tensor;
303+
}
304+
305+
return result;
306+
}
307+
261308
public override WhileContext GetWhileContext()
262309
{
263310
return this;

0 commit comments

Comments
 (0)