1- using System . Threading ;
1+ using System ;
2+ using System . Threading ;
23using static Tensorflow . Binding ;
34
45namespace Tensorflow
56{
67 public partial class ops
78 {
8- private static readonly ThreadLocal < DefaultGraphStack > _defaultGraphFactory = new ThreadLocal < DefaultGraphStack > ( ( ) => new DefaultGraphStack ( ) ) ;
9- private static volatile Session _singleSesson ;
10- private static volatile DefaultGraphStack _singleGraphStack ;
11- private static readonly object _threadingLock = new object ( ) ;
12-
13- public static DefaultGraphStack default_graph_stack
14- {
15- get
16- {
17- if ( ! isSingleThreaded )
18- return _defaultGraphFactory . Value ;
19-
20- if ( _singleGraphStack == null )
21- {
22- lock ( _threadingLock )
23- {
24- if ( _singleGraphStack == null )
25- _singleGraphStack = new DefaultGraphStack ( ) ;
26- }
27- }
28-
29- return _singleGraphStack ;
30- }
31- }
32-
33- private static bool isSingleThreaded = false ;
34-
35- /// <summary>
36- /// Does this library ignore different thread accessing.
37- /// </summary>
38- /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks>
39- public static bool IsSingleThreaded
40- {
41- get => isSingleThreaded ;
42- set
43- {
44- if ( value )
45- enforce_singlethreading ( ) ;
46- else
47- enforce_multithreading ( ) ;
48- }
49- }
50-
51- /// <summary>
52- /// Forces the library to ignore different thread accessing.
53- /// </summary>
54- /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks>
55- public static void enforce_singlethreading ( )
56- {
57- isSingleThreaded = true ;
58- }
59-
60- /// <summary>
61- /// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing.
62- /// </summary>
63- /// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks>
64- public static void enforce_multithreading ( )
65- {
66- isSingleThreaded = false ;
67- }
9+ [ ThreadStatic ]
10+ static DefaultGraphStack default_graph_stack = new DefaultGraphStack ( ) ;
11+ [ ThreadStatic ]
12+ static Session defaultSession ;
6813
6914 /// <summary>
7015 /// Returns the default session for the current thread.
7116 /// </summary>
7217 /// <returns>The default `Session` being used in the current thread.</returns>
7318 public static Session get_default_session ( )
7419 {
75- if ( ! isSingleThreaded )
76- return tf . defaultSession ;
20+ if ( defaultSession == null )
21+ defaultSession = new Session ( tf . get_default_graph ( ) ) ;
7722
78- if ( _singleSesson == null )
79- {
80- lock ( _threadingLock )
81- {
82- if ( _singleSesson == null )
83- _singleSesson = new Session ( ) ;
84- }
85- }
86-
87- return _singleSesson ;
23+ return defaultSession ;
8824 }
8925
9026 /// <summary>
@@ -93,15 +29,8 @@ public static Session get_default_session()
9329 /// <returns>The default `Session` being used in the current thread.</returns>
9430 public static Session set_default_session ( Session sess )
9531 {
96- if ( ! isSingleThreaded )
97- return tf . defaultSession = sess ;
98-
99- lock ( _threadingLock )
100- {
101- _singleSesson = sess ;
102- }
103-
104- return _singleSesson ;
32+ defaultSession = sess ;
33+ return sess ;
10534 }
10635
10736 /// <summary>
@@ -118,10 +47,18 @@ public static Session set_default_session(Session sess)
11847 /// </summary>
11948 /// <returns></returns>
12049 public static Graph get_default_graph ( )
121- => default_graph_stack . get_default ( ) ;
50+ {
51+ if ( default_graph_stack == null )
52+ default_graph_stack = new DefaultGraphStack ( ) ;
53+ return default_graph_stack . get_default ( ) ;
54+ }
12255
12356 public static Graph set_default_graph ( Graph g )
124- => default_graph_stack . get_controller ( g ) ;
57+ {
58+ if ( default_graph_stack == null )
59+ default_graph_stack = new DefaultGraphStack ( ) ;
60+ return default_graph_stack . get_controller ( g ) ;
61+ }
12562
12663 /// <summary>
12764 /// Clears the default graph stack and resets the global default graph.
@@ -135,6 +72,8 @@ public static Graph set_default_graph(Graph g)
13572 /// <returns></returns>
13673 public static void reset_default_graph ( )
13774 {
75+ if ( default_graph_stack == null )
76+ return ;
13877 //if (!_default_graph_stack.is_cleared())
13978 // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
14079 // "nested graphs. If you need a cleared graph, " +
@@ -143,7 +82,11 @@ public static void reset_default_graph()
14382 }
14483
14584 public static Graph peak_default_graph ( )
146- => default_graph_stack . peak_controller ( ) ;
85+ {
86+ if ( default_graph_stack == null )
87+ default_graph_stack = new DefaultGraphStack ( ) ;
88+ return default_graph_stack . peak_controller ( ) ;
89+ }
14790
14891 public static void pop_graph ( )
14992 => default_graph_stack . pop ( ) ;
0 commit comments