Skip to content

Commit 32a0f58

Browse files
committed
ContextSwitchStack
1 parent 6e4bad4 commit 32a0f58

File tree

11 files changed

+57
-31
lines changed

11 files changed

+57
-31
lines changed

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,15 @@ public sealed class Context : IDisposable
3131
public string DeviceName { get; set; } = "";
3232
public string ScopeName { get; set; } = "";
3333
bool initialized = false;
34-
bool isEager;
35-
ContextSwitchStack contextSwitches;
34+
ContextSwitchStack context_switches;
3635

3736
public SafeContextHandle Handle { get; }
3837

3938
public Context(ContextOptions opts, Status status)
4039
{
4140
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
4241
status.Check(true);
43-
isEager = defaultExecutionMode == EAGER_MODE;
44-
contextSwitches = new ContextSwitchStack(isEager);
42+
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE);
4543
initialized = true;
4644
}
4745

@@ -66,7 +64,7 @@ public void end_step()
6664
/// </summary>
6765
/// <returns></returns>
6866
public bool executing_eagerly()
69-
=> isEager;
67+
=> context_switches.Current().EagerMode;
7068

7169
public string shared_name(string name = null)
7270
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ?
@@ -79,9 +77,14 @@ public void graph_mode()
7977
public void eager_mode()
8078
=> mode(true);
8179

82-
void mode(bool mode)
80+
void mode(bool isEager)
8381
{
84-
isEager = mode;
82+
context_switches.Push(isEager);
83+
}
84+
85+
public void restore_mode()
86+
{
87+
context_switches.Pop();
8588
}
8689

8790
public void Dispose()

src/TensorFlowNET.Core/Contexts/ContextSwitch.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ namespace Tensorflow.Contexts
2222
{
2323
public class ContextSwitch
2424
{
25+
public bool EagerMode { get; set; }
26+
2527
/// <summary>
2628
/// Whether the context is building a function.
2729
/// </summary>

src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,25 @@ public class ContextSwitchStack
3030
public ContextSwitchStack(bool isEager)
3131
{
3232
stack = new Stack<ContextSwitch>();
33-
if (isEager)
34-
stack.Push(new ContextSwitch
35-
{
36-
IsBuildingFunction = false
37-
});
33+
Push(isEager);
34+
}
35+
36+
public void Push(bool isEager)
37+
{
38+
stack.Push(new ContextSwitch
39+
{
40+
EagerMode = isEager
41+
});
42+
}
43+
44+
public void Pop()
45+
{
46+
stack.Pop();
47+
}
48+
49+
public ContextSwitch Current()
50+
{
51+
return stack.Peek();
3852
}
3953
}
4054
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true,
148148
/// <returns></returns>
149149
public Graph as_default()
150150
{
151-
tf.Context.graph_mode();
152151
return ops.set_default_graph(this);
153152
}
154153

@@ -492,7 +491,6 @@ public void prevent_fetching(Operation op)
492491

493492
protected override void DisposeManagedResources()
494493
{
495-
tf.Context.eager_mode();
496494
ops.default_graph_stack.remove(this);
497495
}
498496

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
159159
_set_mask_metadata(inputs, outputs, null);
160160
});
161161

162-
tf.Context.eager_mode();
162+
if (!inputs.IsEagerTensor)
163+
tf.Context.restore_mode();
163164

164165
return outputs;
165166
}

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ public class Embedding : Layer
3434
IInitializer embeddings_initializer;
3535

3636
public Embedding(EmbeddingArgs args)
37-
: base(args)
37+
: base(new LayerArgs // copy args
38+
{
39+
DType = args.DType,
40+
Name = args.Name
41+
})
3842
{
3943
this.args = args;
40-
if(args.InputShape == null)
44+
if (args.InputShape == null)
4145
args.InputShape = args.InputLength;
4246

4347
embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer;

src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public InputLayer(InputLayerArgs args) :
3838
{
3939
this.args = args;
4040
built = true;
41-
this.SupportsMasking = true;
41+
SupportsMasking = true;
4242

4343
if(BatchInputShape != null)
4444
{
@@ -58,6 +58,9 @@ public InputLayer(InputLayerArgs args) :
5858
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype;
5959
}
6060

61+
// In graph mode, create a graph placeholder to call the layer on.
62+
tf.Context.graph_mode();
63+
6164
if (args.InputTensor == null)
6265
{
6366
if(args.InputShape != null)
@@ -71,15 +74,13 @@ public InputLayer(InputLayerArgs args) :
7174
args.BatchInputShape = null;
7275
}
7376

74-
// In graph mode, create a graph placeholder to call the layer on.
75-
tf.Context.graph_mode();
7677
args.InputTensor = tf.keras.backend.placeholder(
77-
shape: BatchInputShape,
78-
dtype: DType,
79-
name: Name,
80-
sparse: args.Sparse,
81-
ragged: args.Ragged);
82-
tf.Context.eager_mode();
78+
shape: BatchInputShape,
79+
dtype: DType,
80+
name: Name,
81+
sparse: args.Sparse,
82+
ragged: args.Ragged);
83+
8384

8485
isPlaceholder = true;
8586
}
@@ -97,6 +98,8 @@ public InputLayer(InputLayerArgs args) :
9798
typeSpec = new TensorSpec(args.InputTensor.TensorShape,
9899
dtype: args.InputTensor.dtype,
99100
name: Name);
101+
102+
tf.Context.restore_mode();
100103
}
101104
}
102105
}

src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
<ItemGroup>
3131
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
32-
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.2.0" />
32+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
3333
</ItemGroup>
3434

3535
<ItemGroup>

test/TensorFlowNET.UnitTest/EagerModeTestBase.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ public class EagerModeTestBase : PythonTest
1212
[TestInitialize]
1313
public void TestInit()
1414
{
15-
tf.enable_eager_execution();
15+
if (!tf.executing_eagerly())
16+
tf.enable_eager_execution();
1617
}
1718

1819
[TestCleanup]

test/TensorFlowNET.UnitTest/Keras/LayersTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Keras
1414
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
1515
/// </summary>
1616
[TestClass]
17-
public class LayersTest : GraphModeTestBase
17+
public class LayersTest : EagerModeTestBase
1818
{
1919
[TestMethod]
2020
public void Sequential()
@@ -26,7 +26,7 @@ public void Sequential()
2626
/// <summary>
2727
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
2828
/// </summary>
29-
[TestMethod, Ignore]
29+
[TestMethod]
3030
public void Embedding()
3131
{
3232
var model = new Sequential();

0 commit comments

Comments
 (0)