Skip to content

Commit f4e7fd4

Browse files
Merge branch 'master' of github.com:Wanglongzhi2001/TensorFlow.NET into dev1
2 parents 78bd4c7 + 682f52f commit f4e7fd4

File tree

190 files changed

+73839
-1975
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

190 files changed

+73839
-1975
lines changed

TensorFlow.NET.sln

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
Microsoft Visual Studio Solution File, Format Version 12.00
3-
# Visual Studio Version 16
4-
VisualStudioVersion = 16.0.31624.102
3+
# Visual Studio Version 17
4+
VisualStudioVersion = 17.4.33213.308
55
MinimumVisualStudioVersion = 10.0.40219.1
66
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
77
EndProject
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class c_api
9+
{
10+
[DllImport(TensorFlowLibName)]
11+
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
12+
[DllImport(TensorFlowLibName)]
13+
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
14+
[DllImport(TensorFlowLibName)]
15+
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
16+
}
17+
}

src/TensorFlowNET.Core/APIs/tf.compat.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Google.Protobuf;
1718
using System.Text;
1819

1920
namespace Tensorflow
@@ -45,6 +46,23 @@ internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
4546
{
4647
return as_text(bytes_or_text, encoding);
4748
}
49+
50+
public ByteString as_bytes(ByteString bytes, Encoding encoding = null)
51+
{
52+
return bytes;
53+
}
54+
public ByteString as_bytes(byte[] bytes, Encoding encoding = null)
55+
{
56+
return ByteString.CopyFrom(bytes);
57+
}
58+
public ByteString as_bytes(string text, Encoding encoding = null)
59+
{
60+
if(encoding is null)
61+
{
62+
encoding = Encoding.UTF8;
63+
}
64+
return ByteString.CopyFrom(encoding.GetBytes(text));
65+
}
4866
}
4967

5068
public bool executing_eagerly()

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
5454
Dictionary<string, Tensor> input_map = null,
5555
string[] return_elements = null,
5656
string name = null,
57-
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
57+
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list);
5858
}
5959
}

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Operations;
18+
1719
namespace Tensorflow
1820
{
1921
public partial class tensorflow
@@ -79,5 +81,10 @@ public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
7981
num_split: num_split,
8082
axis: axis,
8183
name: name);
84+
85+
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
86+
{
87+
return gen_ops.ensure_shape(x, shape, name);
88+
}
8289
}
8390
}

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public partial class c_api
6161
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);
6262

6363
[DllImport(TensorFlowLibName)]
64-
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
64+
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status);
6565

6666
/// <summary>
6767
/// Set `num_dims` to -1 to represent "unknown rank".

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
using System.Diagnostics;
2323
using System.IO;
2424
using System.Linq;
25+
using Tensorflow.Operations;
2526

2627
namespace Tensorflow
2728
{

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ public unsafe byte[] ToArray()
107107
}
108108
}
109109

110+
public void Release()
111+
{
112+
_handle.Dispose();
113+
_handle = null;
114+
}
115+
110116
public override string ToString()
111117
=> $"0x{_handle.DangerousGetHandle():x16}";
112118

src/TensorFlowNET.Core/Buffers/TF_Buffer.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,32 @@ public struct TF_Buffer
2525
public IntPtr data;
2626
public ulong length;
2727
public IntPtr data_deallocator;
28+
29+
public unsafe Span<T> AsSpan<T>() where T: unmanaged
30+
{
31+
if(length > int.MaxValue)
32+
{
33+
throw new ValueError($"The length {length} is too large to use in the span.");
34+
}
35+
return new Span<T>(data.ToPointer(), (int)length);
36+
}
37+
38+
public unsafe byte[] ToByteArray()
39+
{
40+
byte[] res = new byte[length];
41+
if(length > int.MaxValue)
42+
{
43+
byte* root = (byte*)data;
44+
for(ulong i = 0; i < length; i++)
45+
{
46+
res[i] = *(root++);
47+
}
48+
}
49+
else
50+
{
51+
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan());
52+
}
53+
return res;
54+
}
2855
}
2956
}

src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public static IList<Trackable> list_objects(ObjectGraphView graph_view)
161161

162162
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
163163
{
164-
return full_list.TakeWhile(x =>
164+
return full_list.Where(x =>
165165
{
166166
var saveables = x.gather_saveables_for_checkpoint();
167167
return saveables is not null && saveables.Count > 0;

0 commit comments

Comments
 (0)