Skip to content

Commit f1fbcf2

Browse files
committed
feat: support model building with RNN.
1 parent 1d97b71 commit f1fbcf2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+3662
-507
lines changed

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using System.Runtime.InteropServices;
19+
using static Tensorflow.CppShapeInferenceResult.Types;
1920

2021
namespace Tensorflow
2122
{
@@ -50,6 +51,19 @@ public static string StringPiece(IntPtr handle)
5051
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
5152
}
5253

54+
public unsafe static byte[] ByteStringPiece(IntPtr handle)
55+
{
56+
byte* str_data = (byte*)handle.ToPointer();
57+
List<byte> bytes = new List<byte>();
58+
byte current = 255;
59+
while (current != ((byte)'\0'))
60+
{
61+
current = *(str_data++);
62+
bytes.Add(current);
63+
}
64+
return bytes.Take(bytes.Count - 1).ToArray();
65+
}
66+
5367
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
5468
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);
5569

src/TensorFlowNET.Core/APIs/tf.control_flow.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
4646
Tensor loop_vars,
4747
int parallel_iterations = 10)
4848
{
49-
Func<Tensor[], Tensor> cond1 = x
49+
Func<Tensors, Tensor> cond1 = x
5050
=> cond(x[0]);
5151

52-
Func<Tensor[], Tensor[]> body1 = x
52+
Func<Tensors, Tensors> body1 = x
5353
=> new[] { body(x[0]) };
5454

5555
var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
5858
return results[0];
5959
}
6060

61-
public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
62-
Func<Tensor[], Tensor[]> body,
63-
Tensor[] loop_vars,
61+
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
62+
Func<Tensors, Tensors> body,
63+
Tensors loop_vars,
6464
int parallel_iterations = 10,
6565
string name = null)
6666
=> control_flow_ops.while_loop(cond, body, loop_vars,

src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@ public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count
1818
return sequence.Take(sequence.Count() - count);
1919
}
2020
#endif
21-
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
21+
public static Tensors ToTensors(this Tensor[] tensors)
22+
{
23+
return new Tensors(tensors);
24+
}
25+
26+
public static Tensors ToTensors(this IList<Tensor> tensors)
2227
{
2328
return new Tensors(tensors);
2429
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This is a temp solution, which should be removed after refactoring `Tensors`
9+
/// </summary>
10+
[Obsolete]
11+
public class FakeTensorByTensorArray: Tensor
12+
{
13+
public TensorArray TensorArray { get; set; }
14+
15+
public FakeTensorByTensorArray(TensorArray array)
16+
{
17+
TensorArray = array;
18+
}
19+
}
20+
}

src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs

Lines changed: 42 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5,136 +5,80 @@
55

66
namespace Tensorflow.Common.Types
77
{
8-
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
8+
public class GeneralizedTensorShape: Nest<Shape>
99
{
10-
public TensorShapeConfig[] Shapes { get; set; }
11-
/// <summary>
12-
/// create a single-dim generalized Tensor shape.
13-
/// </summary>
14-
/// <param name="dim"></param>
15-
public GeneralizedTensorShape(int dim, int size = 1)
16-
{
17-
var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
18-
Shapes = Enumerable.Repeat(elem, size).ToArray();
19-
//Shapes = new TensorShapeConfig[size];
20-
//Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
21-
//Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
22-
////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
23-
}
10+
////public TensorShapeConfig[] Shapes { get; set; }
11+
///// <summary>
12+
///// create a single-dim generalized Tensor shape.
13+
///// </summary>
14+
///// <param name="dim"></param>
15+
//public GeneralizedTensorShape(int dim, int size = 1)
16+
//{
17+
// var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
18+
// Shapes = Enumerable.Repeat(elem, size).ToArray();
19+
// //Shapes = new TensorShapeConfig[size];
20+
// //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
21+
// //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
22+
// ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
23+
//}
2424

25-
public GeneralizedTensorShape(Shape shape)
25+
public GeneralizedTensorShape(Shape value, string? name = null)
2626
{
27-
Shapes = new TensorShapeConfig[] { shape };
27+
NodeValue = value;
28+
NestType = NestType.Node;
2829
}
2930

30-
public GeneralizedTensorShape(TensorShapeConfig shape)
31+
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
3132
{
32-
Shapes = new TensorShapeConfig[] { shape };
33+
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
34+
Name = name;
35+
NestType = NestType.List;
3336
}
3437

35-
public GeneralizedTensorShape(TensorShapeConfig[] shapes)
38+
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
3639
{
37-
Shapes = shapes;
40+
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
41+
Name = name;
42+
NestType = NestType.Dictionary;
3843
}
3944

40-
public GeneralizedTensorShape(IEnumerable<Shape> shape)
45+
public GeneralizedTensorShape(Nest<Shape> other)
4146
{
42-
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
47+
NestType = other.NestType;
48+
NodeValue = other.NodeValue;
49+
DictValue = other.DictValue;
50+
ListValue = other.ListValue;
51+
Name = other.Name;
4352
}
4453

4554
public Shape ToSingleShape()
4655
{
47-
if (Shapes.Length != 1)
56+
var shapes = Flatten().ToList();
57+
if (shapes.Count != 1)
4858
{
4959
throw new ValueError("The generalized shape contains more than 1 dim.");
5060
}
51-
var shape_config = Shapes[0];
52-
Debug.Assert(shape_config is not null);
53-
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
61+
return shapes[0];
5462
}
5563

5664
public long ToNumber()
5765
{
58-
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
66+
var shapes = Flatten().ToList();
67+
if (shapes.Count != 1 || shapes[0].ndim != 1)
5968
{
6069
throw new ValueError("The generalized shape contains more than 1 dim.");
6170
}
62-
var res = Shapes[0].Items[0];
63-
return res is null ? -1 : res.Value;
64-
}
65-
66-
public Shape[] ToShapeArray()
67-
{
68-
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
69-
}
70-
71-
public IEnumerable<long?> Flatten()
72-
{
73-
List<long?> result = new List<long?>();
74-
foreach(var shapeConfig in Shapes)
75-
{
76-
result.AddRange(shapeConfig.Items);
77-
}
78-
return result;
79-
}
80-
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
81-
{
82-
List<Nest<TOut>> lists = new();
83-
foreach(var shapeConfig in Shapes)
84-
{
85-
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
86-
}
87-
return new Nest<TOut>(lists);
88-
}
89-
90-
public Nest<long?> AsNest()
91-
{
92-
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
93-
{
94-
if (config.Items.Length == 0)
95-
{
96-
return Nest<long?>.Empty;
97-
}
98-
else if (config.Items.Length == 1)
99-
{
100-
return new Nest<long?>(config.Items[0]);
101-
}
102-
else
103-
{
104-
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
105-
}
106-
}
107-
108-
if(Shapes.Length == 0)
109-
{
110-
return Nest<long?>.Empty;
111-
}
112-
else if(Shapes.Length == 1)
113-
{
114-
return DealWithSingleShape(Shapes[0]);
115-
}
116-
else
117-
{
118-
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
119-
}
71+
return shapes[0].dims[0];
12072
}
121-
122-
12373

124-
public static implicit operator GeneralizedTensorShape(int dims)
125-
=> new GeneralizedTensorShape(dims);
126-
127-
public IEnumerator<long?[]> GetEnumerator()
74+
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
12875
{
129-
foreach (var shape in Shapes)
130-
{
131-
yield return shape.Items;
132-
}
76+
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
13377
}
13478

135-
IEnumerator IEnumerable.GetEnumerator()
79+
public static implicit operator GeneralizedTensorShape(Shape shape)
13680
{
137-
return GetEnumerator();
81+
return new GeneralizedTensorShape(shape);
13882
}
13983
}
14084
}

src/TensorFlowNET.Core/Common/Types/INest.cs renamed to src/TensorFlowNET.Core/Common/Types/INestStructure.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types
1010
/// </summary>
1111
public interface INestStructure<T>: INestable<T>
1212
{
13+
NestType NestType { get; }
14+
15+
/// <summary>
16+
/// The item count of depth 1 of the nested structure.
17+
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
18+
/// </summary>
19+
int ShallowNestedCount { get; }
20+
/// <summary>
21+
/// The total item count of depth 1 of the nested structure.
22+
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
23+
/// </summary>
24+
int TotalNestedCount { get; }
25+
1326
/// <summary>
1427
/// Flatten the Nestable object. Node that if the object contains only one value,
1528
/// it will be flattened to an enumerable with one element.

src/TensorFlowNET.Core/Common/Types/Nest.Static.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public static class Nest
1313
/// <param name="template"></param>
1414
/// <param name="flatItems"></param>
1515
/// <returns></returns>
16-
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
16+
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
1717
{
1818
return template.AsNest().PackSequence(flatItems);
1919
}

0 commit comments

Comments
 (0)