@@ -19,13 +19,14 @@ limitations under the License.
1919using System . Linq ;
2020using Tensorflow . Eager ;
2121using static Tensorflow . Binding ;
22+ using Google . Protobuf ;
2223
2324namespace Tensorflow . Contexts
2425{
2526 /// <summary>
2627 /// Environment in which eager operations execute.
2728 /// </summary>
28- public sealed class Context : IDisposable
29+ public sealed partial class Context : IDisposable
2930 {
3031 public const int GRAPH_MODE = 0 ;
3132 public const int EAGER_MODE = 1 ;
@@ -37,14 +38,14 @@ public sealed class Context : IDisposable
3738 ContextSwitchStack context_switches ;
3839 public FunctionCallOptions FunctionCallOptions { get ; }
3940
40- public SafeContextHandle Handle { get ; }
41+ SafeContextHandle _handle ;
42+ public SafeContextHandle Handle => _handle ;
4143
42- public Context ( ContextOptions opts , Status status )
44+ public Context ( )
4345 {
44- Handle = c_api . TFE_NewContext ( opts . Handle , status . Handle ) ;
45- status . Check ( true ) ;
46+ _device_policy = ContextDevicePlacementPolicy . DEVICE_PLACEMENT_SILENT ;
4647 context_switches = new ContextSwitchStack ( defaultExecutionMode == EAGER_MODE , false ) ;
47- initialized = true ;
48+ initialized = false ;
4849 FunctionCallOptions = new FunctionCallOptions ( ) ;
4950 }
5051
@@ -55,14 +56,25 @@ public void ensure_initialized()
5556 {
5657 if ( initialized )
5758 return ;
59+
60+ _config = config ( ) ;
61+ var config_str = _config . ToByteArray ( ) ;
62+
63+ using var opts = new ContextOptions ( ) ;
64+ using var status = new Status ( ) ;
65+ c_api . TFE_ContextOptionsSetConfig ( opts . Handle , config_str , ( ulong ) config_str . Length , status . Handle ) ;
66+ status . Check ( true ) ;
67+ c_api . TFE_ContextOptionsSetDevicePlacementPolicy ( opts . Handle , _device_policy ) ;
68+ _handle = c_api . TFE_NewContext ( opts . Handle , status . Handle ) ;
69+ status . Check ( true ) ;
5870 initialized = true ;
5971 }
6072
6173 public void start_step ( )
62- => c_api . TFE_ContextStartStep ( Handle ) ;
74+ => c_api . TFE_ContextStartStep ( _handle ) ;
6375
6476 public void end_step ( )
65- => c_api . TFE_ContextEndStep ( Handle ) ;
77+ => c_api . TFE_ContextEndStep ( _handle ) ;
6678
6779 /// <summary>
6880 /// Checks whether the current thread has eager execution enabled.
@@ -91,61 +103,7 @@ public void restore_mode()
91103 context_switches . Pop ( ) ;
92104 }
93105
94- // [DebuggerStepThrough]
95- public T RunInAutoMode < T > ( Func < T > graphAction , Func < T > eagerAction , params Tensor [ ] tensors )
96- {
97- var shouldRunInEager = executing_eagerly ( )
98- && tensors . Count ( x => x . IsEagerTensor ) == tensors . Length ;
99-
100- if ( shouldRunInEager )
101- return eagerAction ( ) ;
102- else
103- {
104- if ( executing_eagerly ( ) )
105- {
106- graph_mode ( ) ;
107- var result = graphAction ( ) ;
108- restore_mode ( ) ;
109- return result ;
110- }
111- else
112- {
113- return graphAction ( ) ;
114- }
115- }
116- }
117-
118- // [DebuggerStepThrough]
119- public Tensors RunInAutoMode2 ( Func < Tensors > graphAction ,
120- Func < Tensors > eagerAction ,
121- Action < Operation > recordGradient ,
122- Tensors tensors )
123- {
124- var shouldRunInEager = executing_eagerly ( )
125- && tensors . Count ( x => x . IsEagerTensor ) == tensors . Length ;
126-
127- if ( shouldRunInEager )
128- return eagerAction ( ) ;
129- else
130- {
131- if ( executing_eagerly ( ) )
132- {
133- graph_mode ( ) ;
134- var result = graphAction ( ) ;
135- restore_mode ( ) ;
136- return result ;
137- }
138- else
139- {
140- var result = graphAction ( ) ;
141- if ( tf . Runner . MustRecordGradient ( ) )
142- recordGradient ( result [ 0 ] . op ) ;
143- return result ;
144- }
145- }
146- }
147-
148106 public void Dispose ( )
149- => Handle . Dispose ( ) ;
107+ => _handle . Dispose ( ) ;
150108 }
151109}
0 commit comments