Skip to content

Commit ed9a8c8

Browse files
committed
change WhileContext maximum_iterations to Tensor.
1 parent a589552 commit ed9a8c8

File tree

1 file changed

+93
-74
lines changed

1 file changed

+93
-74
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 93 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class WhileContext : ControlFlowContext
4242
public override GradLoopState grad_state => _grad_state;
4343
public override bool back_prop => _back_prop;
4444

45-
public WhileContext(int? maximum_iterations = null,
45+
public WhileContext(Tensor maximum_iterations = null,
4646
int parallel_iterations = 10,
4747
bool back_prop = true,
4848
bool swap_memory = false,
@@ -64,7 +64,7 @@ public WhileContext(int? maximum_iterations = null,
6464
_grad_state = grad_state;
6565
}
6666

67-
private void _init_from_args(int? maximum_iterations,
67+
private void _init_from_args(Tensor maximum_iterations,
6868
int parallel_iterations,
6969
bool back_prop,
7070
bool swap_memory,
@@ -107,9 +107,9 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
107107
/// <summary>
108108
/// Add the loop termination condition and body to the graph.
109109
/// </summary>
110-
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
111-
Func<Tensor, Tensor> body,
112-
Tensor[] loop_vars,
110+
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
111+
Func<Tensor, TItem, LoopVar<TItem>> body,
112+
TItem loop_vars,
113113
TensorShape shape_invariants,
114114
bool return_same_structure)
115115
{
@@ -131,88 +131,107 @@ public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
131131
return packed_exit_vars as Tensor[];
132132
}
133133

134-
private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred,
135-
Func<Tensor, Tensor> body,
136-
Tensor[] original_loop_vars,
137-
Tensor[] loop_vars,
134+
private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
135+
{
136+
if (tensor_or_tensor_array is TensorArray tensor_array)
137+
return tensor_array.flow;
138+
else if (tensor_or_tensor_array is Tensor tensor)
139+
return tensor;
140+
141+
throw new NotImplementedException("_convert_tensorarray_to_flow");
142+
}
143+
144+
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
145+
Func<Tensor, TItem, LoopVar<TItem>> body,
146+
TItem original_loop_vars,
147+
TItem loop_vars,
138148
TensorShape shape_invariants)
139149
{
140150
var flat_loop_vars = original_loop_vars;
141151

152+
// Convert TensorArrays to their flow variables
153+
var loop_vars_tensor = nest.map_structure(
154+
_convert_tensorarray_to_flow,
155+
nest.flatten(loop_vars));
156+
142157
// Let the context know the loop variables so the loop variables
143158
// would be added in the outer contexts properly.
144-
_InitializeValues(loop_vars);
145-
var real_vars = loop_vars;
146-
Tensor[] enter_vars = null;
147-
tf_with(ops.control_dependencies(null), delegate
159+
if (loop_vars is Tensor[] real_vars)
148160
{
149-
enter_vars = real_vars.Select(x => _Enter(x,
150-
_name,
151-
is_constant: false,
152-
parallel_iterations: _parallel_iterations,
153-
use_input_shape: shape_invariants == null))
154-
.ToArray();
155-
156-
foreach(var x in enter_vars)
161+
_InitializeValues(real_vars);
162+
Tensor[] enter_vars = null;
163+
tf_with(ops.control_dependencies(null), delegate
164+
{
165+
enter_vars = real_vars.Select(x => _Enter(x,
166+
_name,
167+
is_constant: false,
168+
parallel_iterations: _parallel_iterations,
169+
use_input_shape: shape_invariants == null))
170+
.ToArray();
171+
172+
foreach (var x in enter_vars)
173+
{
174+
x.graph.prevent_feeding(x);
175+
if (_outer_context != null)
176+
_outer_context.AddInnerOp(x.op);
177+
}
178+
});
179+
180+
// Finds the closest enclosing non-None control pivot.
181+
var outer_context = _outer_context;
182+
while (outer_context != null)
157183
{
158-
x.graph.prevent_feeding(x);
159-
if (_outer_context != null)
160-
_outer_context.AddInnerOp(x.op);
184+
161185
}
162-
});
163186

164-
// Finds the closest enclosing non-None control pivot.
165-
var outer_context = _outer_context;
166-
while (outer_context != null)
167-
{
187+
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);
188+
189+
// Fix the control inputs and control flow context of these enter ops.
190+
_FixControlInputsAndContext(enter_vars);
191+
_InitializeValues(enter_vars);
192+
_loop_enters = enter_vars.ToList();
193+
194+
var merge_vars = enter_vars
195+
.Select(x => merge(new[] { x, x }))
196+
.ToArray();
197+
198+
_pivot_for_pred = merge_vars[0];
199+
200+
// Build the graph for pred.
201+
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
202+
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
203+
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem)));
204+
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
205+
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
206+
.ToArray();
168207

208+
// Build the graph for body.
209+
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
210+
// Convert TensorArray flow variables inside the context back into
211+
// their associated TensorArrays for calling the body.
212+
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
213+
/*var body_result = body(packed_vars_for_body[0]);
214+
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
215+
216+
// Store body_result to keep track of TensorArrays returned by body
217+
var original_body_result = new[] { body_result };
218+
// Convert TensorArrays returned by body into their flow variables
219+
var result = new[] { body_result };
220+
221+
var next_vars = new List<Tensor>();
222+
foreach (var (m, v) in zip(merge_vars, result))
223+
next_vars.Add(_AddNextAndBackEdge(m, v));
224+
225+
// Add the exit ops.
226+
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
227+
_loop_exits = exit_vars;
228+
229+
// Exit the loop.
230+
// ExitResult(exit_vars);
231+
return (original_body_result, exit_vars.ToArray());*/
169232
}
170233

171-
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);
172-
173-
// Fix the control inputs and control flow context of these enter ops.
174-
_FixControlInputsAndContext(enter_vars);
175-
_InitializeValues(enter_vars);
176-
_loop_enters = enter_vars.ToList();
177-
178-
var merge_vars = enter_vars
179-
.Select(x => merge(new[] { x, x }))
180-
.ToArray();
181-
182-
_pivot_for_pred = merge_vars[0];
183-
184-
// Build the graph for pred.
185-
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
186-
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
187-
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0]));
188-
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
189-
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
190-
.ToArray();
191-
192-
// Build the graph for body.
193-
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
194-
// Convert TensorArray flow variables inside the context back into
195-
// their associated TensorArrays for calling the body.
196-
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
197-
var body_result = body(packed_vars_for_body[0]);
198-
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
199-
200-
// Store body_result to keep track of TensorArrays returned by body
201-
var original_body_result = new[] { body_result };
202-
// Convert TensorArrays returned by body into their flow variables
203-
var result = new[] { body_result };
204-
205-
var next_vars = new List<Tensor>();
206-
foreach (var (m, v) in zip(merge_vars, result))
207-
next_vars.Add(_AddNextAndBackEdge(m, v));
208-
209-
// Add the exit ops.
210-
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
211-
_loop_exits = exit_vars;
212-
213-
// Exit the loop.
214-
// ExitResult(exit_vars);
215-
return (original_body_result, exit_vars.ToArray());
234+
throw new NotImplementedException("");
216235
}
217236

218237
private void _FixControlInputsAndContext(Tensor[] enters)

0 commit comments

Comments
 (0)