Skip to content

Commit a65d881

Browse files
committed
CheckInputFromValidContext
1 parent 8e054af commit a65d881

File tree

11 files changed

+117
-33
lines changed

11 files changed

+117
-33
lines changed

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
172172

173173
for (int i = 0; i < input_ta.Count; i++)
174174
{
175-
var (ta, input_) = (input_ta[0], flat_input[0]);
175+
var (ta, input_) = (input_ta[i], flat_input[i]);
176+
ta.unstack(input_);
176177
}
177178
}
178179

@@ -185,16 +186,16 @@ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, T
185186

186187
Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) =>
187188
{
188-
return time < loop_bound;
189+
return item.time < loop_bound;
189190
};
190191

191192
// Take a time step of the dynamic RNN.
192193
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
193194
{
194-
return item;
195+
throw new NotImplementedException("");
195196
};
196197

197-
control_flow_ops.while_loop<BodyItemInRnnWhileLoop>(
198+
control_flow_ops.while_loop(
198199
cond: cond,
199200
body: _time_step,
200201
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),

src/TensorFlowNET.Core/Operations/Operation.Control.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,9 @@ public partial class Operation
3030
/// </summary>
3131
public void _control_flow_post_processing()
3232
{
33-
foreach(var input_tensor in inputs)
33+
foreach(Tensor input_tensor in inputs)
3434
{
35-
//TODO: implement below code dependency
36-
//control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
35+
control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
3736
}
3837

3938
if (_control_flow_context != null)

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
using System.IO;
2424
using System.Linq;
2525
using Tensorflow.Util;
26+
using static Tensorflow.Binding;
2627

2728
namespace Tensorflow
2829
{

src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,25 @@ namespace Tensorflow.Operations
2525
internal class _GraphTensorArray
2626
{
2727
internal TF_DataType _dtype;
28+
public TF_DataType dtype => _dtype;
2829

2930
/// <summary>
3031
/// Used to keep track of what tensors the TensorArray should be
3132
/// colocated with. We choose to colocate the TensorArray with the
3233
/// first tensor written to it.
3334
/// </summary>
3435
bool _colocate_with_first_write_call;
36+
public bool colocate_with_first_write_call => _colocate_with_first_write_call;
3537

3638
bool _infer_shape;
37-
bool _dynamic_size;
38-
List<TensorShape> _element_shape;
39+
public bool infer_shape => _infer_shape;
40+
public bool _dynamic_size;
41+
public List<TensorShape> _element_shape;
3942

40-
List<Tensor> _colocate_with;
43+
public List<Tensor> _colocate_with;
4144

4245
internal Tensor _handle;
46+
public Tensor handle => _handle;
4347
internal Tensor _flow;
4448

4549
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using Tensorflow.Operations.ControlFlows;
2222
using util = Tensorflow.control_flow_util;
2323
using static Tensorflow.Binding;
24+
using Tensorflow.Util;
2425

2526
namespace Tensorflow
2627
{
@@ -251,12 +252,16 @@ public static Tensor _Identity(Tensor data, string name = null)
251252
return gen_array_ops.identity(data, name: name);
252253
}
253254

254-
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null)
255+
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape[] shapes = null)
255256
{
256257
if (shapes == null)
257258
return;
258259

259-
throw new NotImplementedException("_SetShapeInvariants");
260+
var flat_shapes = nest.flatten2(shapes);
261+
foreach (var (inp, var, shape) in zip(input_vars, enter_vars, flat_shapes))
262+
{
263+
var.set_shape(shape);
264+
}
260265
}
261266

262267
/// <summary>
@@ -428,12 +433,12 @@ raise ValueError(
428433
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
429434
.ToArray();
430435

431-
merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);
436+
var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);
432437

433438
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
434439
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
435440

436-
return merges[0];
441+
return new Tensor(IntPtr.Zero);
437442
});
438443
}
439444

@@ -473,22 +478,28 @@ public static Tensor[] cond<T>(Tensor pred,
473478
var res_f_flat = res_f;
474479

475480
var merges = zip(res_f_flat, res_t_flat)
476-
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
481+
.Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
477482
.ToArray();
478483

479-
merges = _convert_flows_to_tensorarrays(orig_res_t, merges);
484+
var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);
480485

481486
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
482487
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
483488

484-
return merges;
489+
return new[] { new Tensor(IntPtr.Zero) };
485490
});
486491
}
487492

488-
public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows)
493+
public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
489494
{
490-
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
491-
return tensors_or_flows;
495+
return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x =>
496+
{
497+
var (ta, t_or_flow) = (x.Item1, x.Item2);
498+
if (ta is TensorArray ta_1)
499+
return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray;
500+
else
501+
return t_or_flow as ITensorOrTensorArray;
502+
}).ToArray();
492503
}
493504

494505
/// <summary>
@@ -592,7 +603,7 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
592603
/// <param name="loop_vars"></param>
593604
/// <param name="i"></param>
594605
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
595-
TensorShape shape_invariants = null,
606+
TensorShape[] shape_invariants = null,
596607
int parallel_iterations = 10,
597608
bool back_prop = true,
598609
bool swap_memory = false,
@@ -617,8 +628,8 @@ public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TIt
617628
var orig_body = body;
618629

619630
LoopVar<TItem> loop_vars_1 = null;
620-
Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null;
621-
Func<Tensor, TItem, Tensor> cond_buildloop = null;
631+
Func<LoopVar<TItem>, LoopVar<TItem>> body_buildloop = null;
632+
Func<LoopVar<TItem>, Tensor> cond_buildloop = null;
622633

623634
if (try_to_pack)
624635
{
@@ -627,9 +638,18 @@ public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TIt
627638
else
628639
{
629640
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars);
630-
cond_buildloop = (i, lv) =>
631-
math_ops.logical_and(i < maximum_iterations, orig_cond(lv));
632-
body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv));
641+
cond_buildloop = (item) =>
642+
{
643+
var (i, lv) = (item.Counter, item.Item);
644+
var oc = orig_cond(lv);
645+
return math_ops.logical_and(i < maximum_iterations, oc);
646+
};
647+
648+
body_buildloop = (item) =>
649+
{
650+
var (i, lv) = (item.Counter, item.Item);
651+
return new LoopVar<TItem>(i + 1, orig_body(lv));
652+
};
633653
}
634654
try_to_pack = false;
635655

src/TensorFlowNET.Core/Operations/control_flow_util.py.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
1718
using Tensorflow.Operations;
19+
using static Tensorflow.Binding;
1820

1921
namespace Tensorflow
2022
{
@@ -53,5 +55,25 @@ public static ControlFlowContext GetOutputContext(Operation op)
5355
ctxt = ctxt.outer_context;
5456
return ctxt;
5557
}
58+
59+
public static void CheckInputFromValidContext(Operation op, Operation input_op)
60+
{
61+
var op_ctxt = op._get_control_flow_context();
62+
var input_ctxt = GetOutputContext(input_op);
63+
var valid = false;
64+
if (input_ctxt == null)
65+
valid = true;
66+
else if (op_ctxt == input_ctxt)
67+
valid = true;
68+
else
69+
{
70+
throw new NotImplementedException("");
71+
}
72+
73+
if (!valid)
74+
{
75+
throw new NotImplementedException("");
76+
}
77+
}
5678
}
5779
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using static Tensorflow.Binding;
18+
1719
namespace Tensorflow
1820
{
1921
public static class gen_math_ops
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class tensor_array_ops
8+
{
9+
/// <summary>
10+
/// Builds a TensorArray with a new `flow` tensor.
11+
/// </summary>
12+
/// <param name="old_ta"></param>
13+
/// <param name="flow"></param>
14+
/// <returns></returns>
15+
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
16+
{
17+
var impl = old_ta._implementation;
18+
19+
var new_ta = new TensorArray(
20+
dtype: impl.dtype,
21+
handle: impl.handle,
22+
flow: flow,
23+
infer_shape: impl.infer_shape,
24+
colocate_with_first_write_call: impl.colocate_with_first_write_call);
25+
26+
var new_impl = new_ta._implementation;
27+
new_impl._dynamic_size = impl._dynamic_size;
28+
new_impl._colocate_with = impl._colocate_with;
29+
new_impl._element_shape = impl._element_shape;
30+
return new_ta;
31+
}
32+
}
33+
}

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ Building, training and infering deep learning models.
2020
https://tensorflownet.readthedocs.io</Description>
2121
<AssemblyVersion>0.12.0.0</AssemblyVersion>
2222
<PackageReleaseNotes>Changes since v0.11.0:
23-
</PackageReleaseNotes>
23+
1: Add ICanBeFlattened for nest.flatten2.
24+
2:</PackageReleaseNotes>
2425
<LangVersion>7.3</LangVersion>
2526
<FileVersion>0.12.0.0</FileVersion>
2627
<PackageLicenseFile>LICENSE</PackageLicenseFile>

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace Tensorflow
3939
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
4040
/// </summary>
4141
[SuppressMessage("ReSharper", "ConvertToAutoProperty")]
42-
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
42+
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray
4343
{
4444
private readonly int _id;
4545
private readonly Operation _op;
@@ -178,7 +178,7 @@ public int[] _shape_tuple()
178178
/// </summary>
179179
public void set_shape(TensorShape shape)
180180
{
181-
this.shape = shape.rank > 0 ? shape.dims : null;
181+
this.shape = shape.rank >= 0 ? shape.dims : null;
182182
}
183183

184184
/// <summary>

0 commit comments

Comments
 (0)