Skip to content

Commit 2001619

Browse files
committed
AutoNumPy annotation.
1 parent 99fc016 commit 2001619

File tree

14 files changed

+173
-70
lines changed

14 files changed

+173
-70
lines changed

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using static Tensorflow.Binding;
2222
using Google.Protobuf;
2323
using Tensorflow.Util;
24+
using Tensorflow.NumPy;
2425

2526
namespace Tensorflow.Contexts
2627
{
@@ -40,8 +41,15 @@ public sealed partial class Context : IDisposable
4041
public FunctionCallOptions FunctionCallOptions { get; }
4142

4243
SafeContextHandle _handle;
43-
public SafeContextHandle Handle => _handle;
44-
44+
public SafeContextHandle Handle
45+
{
46+
get
47+
{
48+
if (_handle == null)
49+
ensure_initialized();
50+
return _handle;
51+
}
52+
}
4553
int? _seed;
4654
Random _rng;
4755

@@ -142,7 +150,11 @@ public bool has_graph_arg(params object[] args)
142150
bool has_graph_arg = !tf.Context.executing_eagerly();
143151
foreach (var el in flatten_args)
144152
{
145-
if (el is Tensor tensor && tensor.IsCreatedInGraphMode)
153+
if (el is NDArray)
154+
continue;
155+
else if (el is EagerTensor)
156+
continue;
157+
else if (el is Tensor)
146158
{
147159
has_graph_arg = true;
148160
break;
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using MethodBoundaryAspect.Fody.Attributes;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using Tensorflow.Eager;
6+
using Tensorflow.Functions;
7+
using static Tensorflow.Binding;
8+
9+
namespace Tensorflow.NumPy
10+
{
11+
public sealed class AutoNumPyAttribute : OnMethodBoundaryAspect
12+
{
13+
bool _changedMode = false;
14+
15+
public override void OnEntry(MethodExecutionArgs args)
16+
{
17+
if (!tf.executing_eagerly())
18+
{
19+
tf.Context.eager_mode();
20+
_changedMode = true;
21+
}
22+
}
23+
24+
public override void OnExit(MethodExecutionArgs args)
25+
{
26+
if (_changedMode)
27+
tf.Context.restore_mode();
28+
}
29+
}
30+
}

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
 using System;
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.IO;
@@ -27,6 +27,33 @@ public NDArray load(string file)
2727
return result.reshape(shape);
2828
}
2929

30+
public Array LoadMatrix(Stream stream)
31+
{
32+
using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: true))
33+
{
34+
int bytes;
35+
Type type;
36+
int[] shape;
37+
if (!ParseReader(reader, out bytes, out type, out shape))
38+
throw new FormatException();
39+
40+
Array matrix = Array.CreateInstance(type, shape);
41+
42+
//if (type == typeof(String))
43+
//return ReadStringMatrix(reader, matrix, bytes, type, shape);
44+
return ReadValueMatrix(reader, matrix, bytes, type, shape);
45+
}
46+
}
47+
48+
public T Load<T>(Stream stream)
49+
where T : class,
50+
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
51+
{
52+
// if (typeof(T).IsArray && (typeof(T).GetElementType().IsArray || typeof(T).GetElementType() == typeof(string)))
53+
// return LoadJagged(stream) as T;
54+
return LoadMatrix(stream) as T;
55+
}
56+
3057
bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape)
3158
{
3259
bytes = 0;

src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,19 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11+
[AutoNumPy]
1112
public static NDArray operator +(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("add", lhs, rhs));
13+
[AutoNumPy]
1214
public static NDArray operator -(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("sub", lhs, rhs));
15+
[AutoNumPy]
1316
public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs));
17+
[AutoNumPy]
1418
public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs));
19+
[AutoNumPy]
1520
public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs));
21+
[AutoNumPy]
1622
public static NDArray operator <(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.less(lhs, rhs));
23+
[AutoNumPy]
1724
public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs));
1825
}
1926
}

src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ namespace Tensorflow.NumPy
99
{
1010
public partial class np
1111
{
12-
public static NDArray logical_or(NDArray x1, NDArray x2)
13-
=> new NDArray(tf.logical_or(x1, x2));
12+
[AutoNumPy]
13+
public static NDArray logical_or(NDArray x1, NDArray x2) => new NDArray(tf.logical_or(x1, x2));
1414

15-
public static NDArray logical_and(NDArray x1, NDArray x2)
16-
=> new NDArray(tf.logical_and(x1, x2));
15+
[AutoNumPy]
16+
public static NDArray logical_and(NDArray x1, NDArray x2) => new NDArray(tf.logical_and(x1, x2));
1717
}
1818
}

src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ namespace Tensorflow.NumPy
88
{
99
public partial class np
1010
{
11+
[AutoNumPy]
1112
public static NDArray argmax(NDArray a, Axis axis = null)
1213
=> new NDArray(math_ops.argmax(a, axis));
1314

15+
[AutoNumPy]
1416
public static NDArray argsort(NDArray a, Axis axis = null)
1517
=> new NDArray(math_ops.argmax(a, axis ?? -1));
1618

19+
[AutoNumPy]
1720
public static NDArray unique(NDArray a)
1821
=> throw new NotImplementedException("");
1922
}

src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ namespace Tensorflow.NumPy
99
{
1010
public partial class np
1111
{
12-
public static NDArray amin(NDArray x, int axis = 0)
13-
=> new NDArray(tf.arg_min(x, axis));
12+
[AutoNumPy]
13+
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis));
1414

15-
public static NDArray amax(NDArray x, int axis = 0)
16-
=> new NDArray(tf.arg_max(x, axis));
15+
[AutoNumPy]
16+
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.arg_max(x, axis));
1717
}
1818
}

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ namespace Tensorflow.NumPy
88
{
99
public partial class np
1010
{
11-
public static NDArray reshape(NDArray x1, Shape newshape)
12-
=> x1.reshape(newshape);
11+
[AutoNumPy]
12+
public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape);
1313

14-
public static NDArray squeeze(NDArray x1, Axis? axis = null)
15-
=> new NDArray(array_ops.squeeze(x1, axis));
14+
[AutoNumPy]
15+
public static NDArray squeeze(NDArray x1, Axis? axis = null) => new NDArray(array_ops.squeeze(x1, axis));
1616
}
1717
}

src/TensorFlowNET.Core/NumPy/Numpy.Math.cs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,33 @@ namespace Tensorflow.NumPy
99
{
1010
public partial class np
1111
{
12-
public static NDArray exp(NDArray x)
13-
=> new NDArray(tf.exp(x));
12+
[AutoNumPy]
13+
public static NDArray exp(NDArray x) => new NDArray(tf.exp(x));
1414

15-
public static NDArray log(NDArray x)
16-
=> new NDArray(tf.log(x));
15+
[AutoNumPy]
16+
public static NDArray log(NDArray x) => new NDArray(tf.log(x));
1717

18-
public static NDArray multiply(NDArray x1, NDArray x2)
19-
=> new NDArray(tf.multiply(x1, x2));
18+
[AutoNumPy]
19+
public static NDArray multiply(NDArray x1, NDArray x2) => new NDArray(tf.multiply(x1, x2));
2020

21-
public static NDArray maximum(NDArray x1, NDArray x2)
22-
=> new NDArray(tf.maximum(x1, x2));
21+
[AutoNumPy]
22+
public static NDArray maximum(NDArray x1, NDArray x2) => new NDArray(tf.maximum(x1, x2));
2323

24-
public static NDArray minimum(NDArray x1, NDArray x2)
25-
=> new NDArray(tf.minimum(x1, x2));
24+
[AutoNumPy]
25+
public static NDArray minimum(NDArray x1, NDArray x2) => new NDArray(tf.minimum(x1, x2));
2626

27+
[AutoNumPy]
2728
public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null, bool keepdims = false)
2829
=> new NDArray(tf.reduce_prod(array, axis: axis));
2930

31+
[AutoNumPy]
3032
public static NDArray prod<T>(params T[] array) where T : unmanaged
3133
=> new NDArray(tf.reduce_prod(new NDArray(array)));
3234

33-
public static NDArray sqrt(NDArray x)
34-
=> new NDArray(tf.sqrt(x));
35+
[AutoNumPy]
36+
public static NDArray sqrt(NDArray x) => new NDArray(tf.sqrt(x));
3537

36-
public static NDArray sum(NDArray x1, Axis? axis = null)
37-
=> new NDArray(tf.math.sum(x1, axis));
38+
[AutoNumPy]
39+
public static NDArray sum(NDArray x1, Axis? axis = null) => new NDArray(tf.math.sum(x1, axis));
3840
}
3941
}

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public NDIterator<T> AsIterator<T>(bool autoreset = false) where T : unmanaged
3838

3939
public bool HasNext() => throw new NotImplementedException("");
4040
public T MoveNext<T>() => throw new NotImplementedException("");
41+
[AutoNumPy]
4142
public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(this, newshape));
4243
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype));
4344
public NDArray ravel() => throw new NotImplementedException("");

0 commit comments

Comments
 (0)