Skip to content

Commit 029779f

Browse files
committed
Multithreading with Keras. #890
1 parent 0a5e181 commit 029779f

File tree

14 files changed

+205
-240
lines changed

14 files changed

+205
-240
lines changed

src/TensorFlowNET.Core/APIs/tf.graph.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ namespace Tensorflow
2020
{
2121
public partial class tensorflow
2222
{
23-
public graph_util_impl graph_util => new graph_util_impl();
24-
public GraphTransformer graph_transforms => new GraphTransformer();
23+
public graph_util_impl graph_util { get; } = new graph_util_impl();
24+
public GraphTransformer graph_transforms { get; } = new GraphTransformer();
2525
public GraphKeys GraphKeys { get; } = new GraphKeys();
2626

2727
public void reset_default_graph()

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public void restore_mode()
171171

172172
public void reset_context()
173173
{
174-
ops.reset_uid();
174+
// ops.reset_uid();
175175
// tf.defaultSession = null;
176176
ops.reset_default_graph();
177177
context_switches.Clear();

src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
1718
using System.Collections.Generic;
18-
using System.Linq;
19-
using static Tensorflow.Binding;
2019

2120
namespace Tensorflow
2221
{
@@ -25,19 +24,14 @@ namespace Tensorflow
2524
/// </summary>
2625
public class DefaultGraphStack
2726
{
28-
private readonly Stack<Graph> _stack = new Stack<Graph>();
29-
Graph _global_default_graph;
27+
Stack<Graph> _stack = new Stack<Graph>();
3028

3129
public Graph get_default()
3230
{
33-
if (_stack.Count > 0)
34-
return _stack.Peek();
35-
else if (_global_default_graph != null)
36-
return _global_default_graph;
37-
else
38-
_global_default_graph = new Graph();
31+
if (_stack.Count == 0)
32+
_stack.Push(new Graph());
3933

40-
return _global_default_graph;
34+
return _stack.Peek();
4135
}
4236

4337
public Graph get_controller(Graph g)
@@ -61,7 +55,6 @@ public void pop()
6155
public void reset()
6256
{
6357
_stack.Clear();
64-
_global_default_graph = null;
6558
}
6659
}
6760
}

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ unsafe NDArray GetData(Slice[] slices)
107107
if (tensor.Handle == null)
108108
{
109109
if (tf.executing_eagerly())
110-
tensor = tf.defaultSession.eval(tensor);
110+
tensor = tf.get_default_session().eval(tensor);
111111
}
112112

113113
return new NDArray(tensor, tf.executing_eagerly());

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: c
3838
{
3939
if (_handle is null)
4040
{
41-
tensor = tf.defaultSession.eval(tensor);
41+
tensor = tf.get_default_session().eval(tensor);
4242
_handle = tensor.Handle;
4343
}
4444

src/TensorFlowNET.Core/Variables/EagerResourceDeleter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ protected override void DisposeUnmanagedResources(IntPtr handle)
2323
{
2424
// gen_resource_variable_ops.destroy_resource_op(_tensor, ignore_lookup_error: true);
2525

26-
tf.device(_handle_device);
26+
// tf.device(_handle_device);
2727
tf.Runner.TFE_Execute(tf.Context, _handle_device, "DestroyResourceOp",
2828
new[] { _tensor },
2929
new object[] { "ignore_lookup_error", true }, 0);

src/TensorFlowNET.Core/ops.threading.cs

Lines changed: 28 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,26 @@
1-
using System.Threading;
1+
using System;
2+
using System.Threading;
23
using static Tensorflow.Binding;
34

45
namespace Tensorflow
56
{
67
public partial class ops
78
{
8-
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());
9-
private static volatile Session _singleSesson;
10-
private static volatile DefaultGraphStack _singleGraphStack;
11-
private static readonly object _threadingLock = new object();
12-
13-
public static DefaultGraphStack default_graph_stack
14-
{
15-
get
16-
{
17-
if (!isSingleThreaded)
18-
return _defaultGraphFactory.Value;
19-
20-
if (_singleGraphStack == null)
21-
{
22-
lock (_threadingLock)
23-
{
24-
if (_singleGraphStack == null)
25-
_singleGraphStack = new DefaultGraphStack();
26-
}
27-
}
28-
29-
return _singleGraphStack;
30-
}
31-
}
32-
33-
private static bool isSingleThreaded = false;
34-
35-
/// <summary>
36-
/// Does this library ignore different thread accessing.
37-
/// </summary>
38-
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks>
39-
public static bool IsSingleThreaded
40-
{
41-
get => isSingleThreaded;
42-
set
43-
{
44-
if (value)
45-
enforce_singlethreading();
46-
else
47-
enforce_multithreading();
48-
}
49-
}
50-
51-
/// <summary>
52-
/// Forces the library to ignore different thread accessing.
53-
/// </summary>
54-
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks>
55-
public static void enforce_singlethreading()
56-
{
57-
isSingleThreaded = true;
58-
}
59-
60-
/// <summary>
61-
/// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing.
62-
/// </summary>
63-
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks>
64-
public static void enforce_multithreading()
65-
{
66-
isSingleThreaded = false;
67-
}
9+
[ThreadStatic]
10+
static DefaultGraphStack default_graph_stack = new DefaultGraphStack();
11+
[ThreadStatic]
12+
static Session defaultSession;
6813

6914
/// <summary>
7015
/// Returns the default session for the current thread.
7116
/// </summary>
7217
/// <returns>The default `Session` being used in the current thread.</returns>
7318
public static Session get_default_session()
7419
{
75-
if (!isSingleThreaded)
76-
return tf.defaultSession;
20+
if (defaultSession == null)
21+
defaultSession = new Session(tf.get_default_graph());
7722

78-
if (_singleSesson == null)
79-
{
80-
lock (_threadingLock)
81-
{
82-
if (_singleSesson == null)
83-
_singleSesson = new Session();
84-
}
85-
}
86-
87-
return _singleSesson;
23+
return defaultSession;
8824
}
8925

9026
/// <summary>
@@ -93,15 +29,8 @@ public static Session get_default_session()
9329
/// <returns>The default `Session` being used in the current thread.</returns>
9430
public static Session set_default_session(Session sess)
9531
{
96-
if (!isSingleThreaded)
97-
return tf.defaultSession = sess;
98-
99-
lock (_threadingLock)
100-
{
101-
_singleSesson = sess;
102-
}
103-
104-
return _singleSesson;
32+
defaultSession = sess;
33+
return sess;
10534
}
10635

10736
/// <summary>
@@ -118,10 +47,18 @@ public static Session set_default_session(Session sess)
11847
/// </summary>
11948
/// <returns></returns>
12049
public static Graph get_default_graph()
121-
=> default_graph_stack.get_default();
50+
{
51+
if (default_graph_stack == null)
52+
default_graph_stack = new DefaultGraphStack();
53+
return default_graph_stack.get_default();
54+
}
12255

12356
public static Graph set_default_graph(Graph g)
124-
=> default_graph_stack.get_controller(g);
57+
{
58+
if (default_graph_stack == null)
59+
default_graph_stack = new DefaultGraphStack();
60+
return default_graph_stack.get_controller(g);
61+
}
12562

12663
/// <summary>
12764
/// Clears the default graph stack and resets the global default graph.
@@ -135,6 +72,8 @@ public static Graph set_default_graph(Graph g)
13572
/// <returns></returns>
13673
public static void reset_default_graph()
13774
{
75+
if (default_graph_stack == null)
76+
return;
13877
//if (!_default_graph_stack.is_cleared())
13978
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
14079
// "nested graphs. If you need a cleared graph, " +
@@ -143,7 +82,11 @@ public static void reset_default_graph()
14382
}
14483

14584
public static Graph peak_default_graph()
146-
=> default_graph_stack.peak_controller();
85+
{
86+
if (default_graph_stack == null)
87+
default_graph_stack = new DefaultGraphStack();
88+
return default_graph_stack.peak_controller();
89+
}
14790

14891
public static void pop_graph()
14992
=> default_graph_stack.pop();

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using Serilog;
1818
using Serilog.Core;
19+
using System.Threading;
1920
using Tensorflow.Contexts;
2021
using Tensorflow.Eager;
2122
using Tensorflow.Gradients;
@@ -38,25 +39,27 @@ public partial class tensorflow
3839
public TF_DataType chars = TF_DataType.TF_STRING;
3940
public TF_DataType @string = TF_DataType.TF_STRING;
4041

41-
public Status Status;
4242
public OpDefLibrary OpDefLib;
43-
public Context Context;
44-
public IEagerRunner Runner;
4543
public Logger Logger;
4644

45+
ThreadLocal<Status> _status = new ThreadLocal<Status>(() => new Status());
46+
public Status Status => _status.Value;
47+
48+
ThreadLocal<Context> _context = new ThreadLocal<Context>(() => new Context());
49+
public Context Context => _context.Value;
50+
51+
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
52+
public IEagerRunner Runner => _runner.Value;
53+
4754
public tensorflow()
4855
{
4956
Logger = new LoggerConfiguration()
5057
.MinimumLevel.Error()
5158
.WriteTo.Console()
5259
.CreateLogger();
5360

54-
Status = new Status();
55-
Context = new Context();
5661
OpDefLib = new OpDefLibrary();
57-
ConstructThreadingObjects();
5862
InitGradientEnvironment();
59-
Runner = new EagerRunner();
6063
}
6164

6265
public string VERSION => c_api.StringPiece(c_api.TF_Version());

src/TensorFlowNET.Core/tensorflow.threading.cs

Lines changed: 0 additions & 53 deletions
This file was deleted.

src/TensorFlowNET.Keras/KerasInterface.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Tensorflow.Keras.Optimizers;
1313
using Tensorflow.Keras.Saving;
1414
using Tensorflow.Keras.Utils;
15+
using System.Threading;
1516

1617
namespace Tensorflow.Keras
1718
{

0 commit comments

Comments
 (0)