Skip to content

Commit 547c4e6

Browse files
committed
tf.while_loop, add ICanBeFlattened #348
1 parent f7cee4b commit 547c4e6

File tree

7 files changed

+48
-32
lines changed

7 files changed

+48
-32
lines changed

src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,25 @@
44

55
namespace Tensorflow.Operations
66
{
7-
internal class LoopVar<TItem>
7+
internal class LoopVar<TItem> : ICanBeFlattened
88
{
99
public Tensor Counter { get; }
10-
public TItem[] Items { get; }
1110
public TItem Item { get; }
1211

13-
public LoopVar(Tensor counter, TItem[] items)
12+
public LoopVar(Tensor counter, TItem item)
1413
{
1514
Counter = counter;
16-
Items = items;
15+
Item = item;
1716
}
1817

19-
public LoopVar(Tensor counter, TItem item)
18+
public object[] Flatten()
2019
{
21-
Counter = counter;
22-
Item = item;
20+
var elements = new List<object> { Counter };
21+
if (typeof(TItem).GetInterface(typeof(ICanBeFlattened).Name) != null)
22+
elements.AddRange((Item as ICanBeFlattened).Flatten());
23+
else
24+
elements.Add(Item);
25+
return elements.ToArray();
2326
}
2427
}
2528
}

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
109109
/// </summary>
110110
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
111111
Func<Tensor, TItem, LoopVar<TItem>> body,
112-
TItem loop_vars,
112+
LoopVar<TItem> loop_vars,
113113
TensorShape shape_invariants,
114114
bool return_same_structure)
115115
{
@@ -143,16 +143,16 @@ private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
143143

144144
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
145145
Func<Tensor, TItem, LoopVar<TItem>> body,
146-
TItem original_loop_vars,
147-
TItem loop_vars,
146+
LoopVar<TItem> original_loop_vars,
147+
LoopVar<TItem> loop_vars,
148148
TensorShape shape_invariants)
149149
{
150150
var flat_loop_vars = original_loop_vars;
151151

152152
// Convert TensorArrays to their flow variables
153153
var loop_vars_tensor = nest.map_structure(
154154
_convert_tensorarray_to_flow,
155-
nest.flatten(loop_vars));
155+
nest.flatten2(loop_vars));
156156

157157
// Let the context know the loop variables so the loop variables
158158
// would be added in the outer contexts properly.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
public interface ICanBeFlattened
8+
{
9+
object[] Flatten();
10+
}
11+
}

src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Tensorflow.Operations
66
{
7-
internal class BodyItemInRnnWhileLoop
7+
internal class BodyItemInRnnWhileLoop : ICanBeFlattened
88
{
99
/// <summary>
1010
/// int32 scalar Tensor.
@@ -28,5 +28,13 @@ public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor sta
2828

2929
public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item)
3030
=> (item.time, item.output_ta_t, item.state);
31+
32+
public object[] Flatten()
33+
{
34+
var elements = new List<object> { time };
35+
elements.AddRange(output_ta_t);
36+
elements.Add(state);
37+
return elements.ToArray();
38+
}
3139
}
3240
}

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TIt
642642
if (loop_context.outer_context == null)
643643
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);
644644

645-
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants,
645+
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants,
646646
return_same_structure);
647647

648648
if (maximum_iterations != null)

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>1.14.0</TargetTensorFlow>
8-
<Version>0.11.8.1</Version>
8+
<Version>0.12.0</Version>
99
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
1010
<Company>SciSharp STACK</Company>
1111
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,25 +16,13 @@
1616
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
1717
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
1818
<Description>Google's TensorFlow full binding in .NET Standard.
19-
Docs: https://tensorflownet.readthedocs.io</Description>
20-
<AssemblyVersion>0.11.8.1</AssemblyVersion>
21-
<PackageReleaseNotes>Changes since v0.10.0:
22-
1. Upgrade NumSharp to v0.20.3.
23-
2. Add DisposableObject class to manage object lifetime.
24-
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
25-
4. Change tensorflow to non-static class in order to execute some initialization process.
26-
5. Overload session.run(), make syntax simpler.
27-
6. Add Local Response Normalization.
28-
7. Add tf.image related APIs.
29-
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
30-
9. MultiThread is safe.
31-
10. Support n-dim indexing for tensor.
32-
11. Add RegisterNoGradients
33-
12. Add CumsumGrad, BroadcastToGrad.
34-
13. Return VariableV1 instead of RefVariable.
35-
14. Add Tensor overload to GradientDescentOptimizer.</PackageReleaseNotes>
19+
Building, training and infering deep learning models.
20+
https://tensorflownet.readthedocs.io</Description>
21+
<AssemblyVersion>0.12.0.0</AssemblyVersion>
22+
<PackageReleaseNotes>Changes since v0.11.0:
23+
</PackageReleaseNotes>
3624
<LangVersion>7.3</LangVersion>
37-
<FileVersion>0.11.8.1</FileVersion>
25+
<FileVersion>0.12.0.0</FileVersion>
3826
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3927
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
4028
<SignAssembly>true</SignAssembly>

src/TensorFlowNET.Core/Util/nest.py.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Collections.Generic;
2020
using System.Linq;
2121
using NumSharp;
22+
using Tensorflow.Operations;
2223

2324
namespace Tensorflow.Util
2425
{
@@ -221,6 +222,11 @@ public static List<T> flatten<T>(T structure)
221222
return list;
222223
}
223224

225+
public static object[] flatten2(ICanBeFlattened structure)
226+
{
227+
return structure.Flatten();
228+
}
229+
224230
private static void _flatten_recursive<T>(T obj, List<T> list)
225231
{
226232
switch(obj)

0 commit comments

Comments
 (0)