Skip to content

Commit 0454c7b

Browse files
authored
Merge pull request #1110 from Wanglongzhi2001/master
feat: Support training of RNN and LSTM.
2 parents 1b1a503 + 35d2e10 commit 0454c7b

File tree

160 files changed

+11061
-867
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

160 files changed

+11061
-867
lines changed

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using System.Runtime.InteropServices;
19+
using static Tensorflow.CppShapeInferenceResult.Types;
1920

2021
namespace Tensorflow
2122
{
@@ -50,6 +51,19 @@ public static string StringPiece(IntPtr handle)
5051
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
5152
}
5253

54+
public unsafe static byte[] ByteStringPiece(IntPtr handle)
55+
{
56+
byte* str_data = (byte*)handle.ToPointer();
57+
List<byte> bytes = new List<byte>();
58+
byte current = 255;
59+
while (current != ((byte)'\0'))
60+
{
61+
current = *(str_data++);
62+
bytes.Add(current);
63+
}
64+
return bytes.Take(bytes.Count - 1).ToArray();
65+
}
66+
5367
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
5468
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);
5569

src/TensorFlowNET.Core/APIs/tf.control_flow.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
4646
Tensor loop_vars,
4747
int parallel_iterations = 10)
4848
{
49-
Func<Tensor[], Tensor> cond1 = x
49+
Func<Tensors, Tensor> cond1 = x
5050
=> cond(x[0]);
5151

52-
Func<Tensor[], Tensor[]> body1 = x
52+
Func<Tensors, Tensors> body1 = x
5353
=> new[] { body(x[0]) };
5454

5555
var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
5858
return results[0];
5959
}
6060

61-
public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
62-
Func<Tensor[], Tensor[]> body,
63-
Tensor[] loop_vars,
61+
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
62+
Func<Tensors, Tensors> body,
63+
Tensors loop_vars,
6464
int parallel_iterations = 10,
6565
string name = null)
6666
=> control_flow_ops.while_loop(cond, body, loop_vars,

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ public Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides = n
7171
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
7272
=> array_ops.split(
7373
value: value,
74-
num_split: num_split,
74+
num_or_size_splits: num_split,
7575
axis: axis,
7676
name: name);
7777

7878
public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
7979
=> array_ops.split(
8080
value: value,
81-
num_split: num_split,
82-
axis: axis,
81+
num_or_size_splits: num_split,
82+
axis: ops.convert_to_tensor(axis),
8383
name: name);
8484

8585
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ public static TF_DataType GetDataType(this object data)
524524
case Tensors tensors:
525525
return tensors.dtype;
526526
case IEnumerable<Tensor> tensors:
527-
return tensors.First().dtype;
527+
return tensors.Where(x => x is not null).First().dtype;
528528
case RefVariable variable:
529529
return variable.dtype;
530530
case ResourceVariable variable:

src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs renamed to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Extensions
6+
namespace Tensorflow.Common.Extensions
77
{
88
public static class JObjectExtensions
99
{
1010
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
1111
{
1212
var res = obj[key];
13-
if(res is null)
13+
if (res is null)
1414
{
15-
return default(T);
15+
return default;
1616
}
1717
else
1818
{
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class LinqExtensions
9+
{
10+
#if NETSTANDARD2_0
11+
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
12+
{
13+
return sequence.Skip(sequence.Count() - count);
14+
}
15+
16+
public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
17+
{
18+
return sequence.Take(sequence.Count() - count);
19+
}
20+
#endif
21+
public static Tensors ToTensors(this Tensor[] tensors)
22+
{
23+
return new Tensors(tensors);
24+
}
25+
26+
public static Tensors ToTensors(this IList<Tensor> tensors)
27+
{
28+
return new Tensors(tensors);
29+
}
30+
31+
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
32+
{
33+
first = values.Item1;
34+
second = values.Item2;
35+
third = values.Item3;
36+
}
37+
}
38+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Common.Types;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class NestExtensions
9+
{
10+
public static Tensors ToTensors(this INestable<Tensor> tensors)
11+
{
12+
return new Tensors(tensors.AsNest());
13+
}
14+
15+
public static Tensors? ToTensors(this Nest<Tensor> tensors)
16+
{
17+
return Tensors.FromNest(tensors);
18+
}
19+
20+
/// <summary>
21+
/// If the nested object is already a nested type, this function could reduce it.
22+
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
23+
/// </summary>
24+
/// <typeparam name="TIn"></typeparam>
25+
/// <typeparam name="TOut"></typeparam>
26+
/// <param name="input"></param>
27+
/// <returns></returns>
28+
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
29+
{
30+
return Nest<TOut>.ReduceFrom(input);
31+
}
32+
}
33+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This is a temp solution, which should be removed after refactoring `Tensors`
9+
/// </summary>
10+
[Obsolete]
11+
public class FakeTensorByTensorArray: Tensor
12+
{
13+
public TensorArray TensorArray { get; set; }
14+
15+
public FakeTensorByTensorArray(TensorArray array)
16+
{
17+
TensorArray = array;
18+
}
19+
}
20+
}

0 commit comments

Comments
 (0)