Skip to content

Commit 6931d5c

Browse files
committed
change Session(ConfigProto).
1 parent 5a0648a commit 6931d5c

File tree

6 files changed

+21
-50
lines changed

6 files changed

+21
-50
lines changed

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,20 @@ public class BaseSession : DisposableObject
3636
protected byte[] _target;
3737
public Graph graph => _graph;
3838

39-
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null)
39+
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
4040
{
4141
_graph = g ?? ops.get_default_graph();
4242
_graph.as_default();
4343
_target = Encoding.UTF8.GetBytes(target);
4444

45-
SessionOptions lopts = opts ?? new SessionOptions();
46-
47-
lock (Locks.ProcessWide)
45+
using (var opts = new SessionOptions(target, config))
4846
{
49-
status = status ?? new Status();
50-
_handle = c_api.TF_NewSession(_graph, opts ?? lopts, status);
51-
status.Check(true);
47+
lock (Locks.ProcessWide)
48+
{
49+
status = status ?? new Status();
50+
_handle = c_api.TF_NewSession(_graph, opts, status);
51+
status.Check(true);
52+
}
5253
}
5354
}
5455

src/TensorFlowNET.Core/Sessions/Session.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public Session(IntPtr handle, Graph g = null) : base("", g, null)
3232
_handle = handle;
3333
}
3434

35-
public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s)
35+
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
3636
{ }
3737

3838
public Session as_default()

src/TensorFlowNET.Core/Sessions/SessionOptions.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ limitations under the License.
2020

2121
namespace Tensorflow
2222
{
23-
public class SessionOptions : DisposableObject
23+
internal class SessionOptions : DisposableObject
2424
{
25-
public SessionOptions()
25+
public SessionOptions(string target = "", ConfigProto config = null)
2626
{
2727
_handle = c_api.TF_NewSessionOptions();
28+
c_api.TF_SetTarget(_handle, target);
29+
if (config != null)
30+
SetConfig(config);
2831
}
2932

3033
public SessionOptions(IntPtr handle)
@@ -35,10 +38,10 @@ public SessionOptions(IntPtr handle)
3538
protected override void DisposeUnmanagedResources(IntPtr handle)
3639
=> c_api.TF_DeleteSessionOptions(handle);
3740

38-
public void SetConfig(ConfigProto config)
41+
private void SetConfig(ConfigProto config)
3942
{
40-
var bytes = config.ToByteArray(); //TODO! we can use WriteTo
41-
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak
43+
var bytes = config.ToByteArray();
44+
var proto = Marshal.AllocHGlobal(bytes.Length);
4245
Marshal.Copy(bytes, 0, proto, bytes.Length);
4346

4447
using (var status = new Status())

src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/TensorFlowNET.Core/Sessions/c_api.session.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_op
116116
/// <param name="proto_len">size_t</param>
117117
/// <param name="status">TF_Status*</param>
118118
[DllImport(TensorFlowLibName)]
119-
public static extern unsafe void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status);
119+
public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status);
120+
121+
[DllImport(TensorFlowLibName)]
122+
public static extern void TF_SetTarget(IntPtr options, string target);
120123
}
121124
}

0 commit comments

Comments
 (0)