Skip to content

Commit 07f70f9

Browse files
committed
WhileContext BuildLoop
1 parent ed9a8c8 commit 07f70f9

File tree

12 files changed

+117
-40
lines changed

12 files changed

+117
-40
lines changed

src/TensorFlowNET.Core/APIs/tf.control_flow.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public Tensor cond(Tensor pred,
3737
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation
3838
=> control_flow_ops.group(inputs, name: name);
3939

40-
public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
40+
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
4141
TensorShape shape_invariants = null,
4242
int parallel_iterations = 10,
4343
bool back_prop = true,
@@ -52,7 +52,7 @@ public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, T
5252
swap_memory: swap_memory,
5353
name: name,
5454
maximum_iterations: maximum_iterations,
55-
return_same_structure: return_same_structure);
55+
return_same_structure: return_same_structure);*/
5656

5757
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
5858
=> ops.control_dependencies(control_inputs);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
internal class LoopVar<TItem>
8+
{
9+
public Tensor Counter { get; }
10+
public TItem[] Items { get; }
11+
public TItem Item { get; }
12+
13+
public LoopVar(Tensor counter, TItem[] items)
14+
{
15+
Counter = counter;
16+
Items = items;
17+
}
18+
19+
public LoopVar(Tensor counter, TItem item)
20+
{
21+
Counter = counter;
22+
Item = item;
23+
}
24+
}
25+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
internal class BodyItemInRnnWhileLoop
8+
{
9+
/// <summary>
10+
/// int32 scalar Tensor.
11+
/// </summary>
12+
public Tensor time { get; set; }
13+
/// <summary>
14+
/// List of `TensorArray`s that represent the output.
15+
/// </summary>
16+
public TensorArray[] output_ta_t { get; set; }
17+
/// <summary>
18+
/// nested tuple of vector tensors that represent the state.
19+
/// </summary>
20+
public Tensor state { get; set; }
21+
22+
public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state)
23+
{
24+
this.time = time;
25+
this.output_ta_t = output_ta_t;
26+
this.state = state;
27+
}
28+
29+
public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item)
30+
=> (item.time, item.output_ta_t, item.state);
31+
}
32+
}

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
145145
{
146146
var ta = new TensorArray(dtype: dtype_,
147147
size: time_steps,
148-
element_shape: new[] { element_shape },
148+
element_shape: element_shape,
149149
tensor_array_name: base_name + name);
150150
return ta;
151151
};
@@ -178,19 +178,29 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
178178

179179
// Make sure that we run at least 1 step, if necessary, to ensure
180180
// the TensorArrays pick up the dynamic shape.
181-
Tensor loop_bound;
181+
Tensor loop_bound = null;
182182
if (in_graph_mode)
183183
loop_bound = math_ops.minimum(
184184
time_steps, math_ops.maximum(1, max_sequence_length));
185185

186-
/*Func<Tensor, Tensor> cond = (ctime) =>
186+
Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) =>
187187
{
188-
return null;
188+
return time < loop_bound;
189189
};
190190

191-
control_flow_ops.while_loop(
191+
// Take a time step of the dynamic RNN.
192+
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
193+
{
194+
return item;
195+
};
196+
197+
control_flow_ops.while_loop<BodyItemInRnnWhileLoop>(
192198
cond: cond,
193-
body = );*/
199+
body: _time_step,
200+
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),
201+
parallel_iterations: parallel_iterations,
202+
maximum_iterations: time_steps,
203+
swap_memory: swap_memory);
194204

195205
throw new NotImplementedException("");
196206
}

src/TensorFlowNET.Core/Operations/TensorArray.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class TensorArray
3939

4040
public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
4141
string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
42-
bool infer_shape = true, TensorShape[] element_shape = null,
42+
bool infer_shape = true, TensorShape element_shape = null,
4343
bool colocate_with_first_write_call = true, string name = null)
4444
{
4545
_implementation = new _GraphTensorArray(dtype,

src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ internal class _GraphTensorArray
4444

4545
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
4646
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
47-
bool infer_shape = true, TensorShape[] element_shape = null,
47+
bool infer_shape = true, TensorShape element_shape = null,
4848
bool colocate_with_first_write_call = true, string name = null)
4949
{
5050
clear_after_read = clear_after_read ?? true;
@@ -68,7 +68,7 @@ public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = nu
6868
else
6969
{
7070
_infer_shape = true;
71-
_element_shape = new List<TensorShape> { };
71+
_element_shape = new List<TensorShape> { element_shape };
7272
}
7373

7474
tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope =>
@@ -135,7 +135,7 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
135135

136136
var ta = new TensorArray(_dtype,
137137
infer_shape:_infer_shape,
138-
element_shape: _element_shape.ToArray(),
138+
element_shape: _element_shape[0],
139139
dynamic_size: _dynamic_size,
140140
handle: _handle,
141141
flow: flow_out,

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ public static Tensor[] cond<T>(Tensor pred,
485485
});
486486
}
487487

488-
public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
488+
public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows)
489489
{
490490
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
491491
return tensors_or_flows;
@@ -591,18 +591,18 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
591591
/// <param name="body"></param>
592592
/// <param name="loop_vars"></param>
593593
/// <param name="i"></param>
594-
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
594+
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
595595
TensorShape shape_invariants = null,
596596
int parallel_iterations = 10,
597597
bool back_prop = true,
598598
bool swap_memory = false,
599599
string name = null,
600-
int? maximum_iterations = null,
600+
Tensor maximum_iterations = null,
601601
bool return_same_structure = false)
602602
{
603603
tf_with(ops.name_scope(name, "while", loop_vars), scope =>
604604
{
605-
if (loop_vars == null || loop_vars.Length == 0)
605+
if (loop_vars == null)
606606
throw new ValueError("No loop variables provided");
607607
if (cond == null)
608608
throw new ValueError("cond must be callable.");
@@ -611,6 +611,28 @@ public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor>
611611
if (parallel_iterations < 1)
612612
throw new ValueError("parallel_iterations must be a positive integer.");
613613

614+
var try_to_pack = loop_vars is Tensor && !return_same_structure;
615+
var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter");
616+
var orig_cond = cond;
617+
var orig_body = body;
618+
619+
LoopVar<TItem> loop_vars_1 = null;
620+
Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null;
621+
Func<Tensor, TItem, Tensor> cond_buildloop = null;
622+
623+
if (try_to_pack)
624+
{
625+
626+
}
627+
else
628+
{
629+
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars);
630+
cond_buildloop = (i, lv) =>
631+
math_ops.logical_and(i < maximum_iterations, orig_cond(lv));
632+
body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv));
633+
}
634+
try_to_pack = false;
635+
614636
var loop_context = new WhileContext(
615637
maximum_iterations: maximum_iterations,
616638
parallel_iterations: parallel_iterations,
@@ -620,7 +642,7 @@ public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor>
620642
if (loop_context.outer_context == null)
621643
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);
622644

623-
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
645+
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants,
624646
return_same_structure);
625647

626648
if (maximum_iterations != null)

src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name
2828
}
2929

3030
public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid,
31-
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
32-
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
31+
TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
32+
bool identical_element_shapes = false, string tensor_array_name = "", string name = null)
3333
{
34-
if (tensor_array_name == null)
35-
tensor_array_name = string.Empty;
36-
3734
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
3835
{
3936
size,

src/TensorFlowNET.Core/Util/nest.py.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ public static List<T> flatten<T>(T structure)
223223

224224
private static void _flatten_recursive<T>(T obj, List<T> list)
225225
{
226-
227226
switch(obj)
228227
{
229228
case IDictionary dict:

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ public Session Session()
9393
return new Session().as_default();
9494
}
9595

96-
public Session Session(Graph graph, SessionOptions opts = null)
96+
public Session Session(Graph graph, ConfigProto config = null)
9797
{
98-
return new Session(graph, opts: opts).as_default();
98+
return new Session(graph, config: config).as_default();
9999
}
100100

101-
public Session Session(SessionOptions opts)
101+
public Session Session(ConfigProto config)
102102
{
103-
return new Session(null, opts).as_default();
103+
return new Session(null, config).as_default();
104104
}
105105

106106
public void __init__()

0 commit comments

Comments
 (0)