Skip to content

Commit 05dd652

Browse files
committed
Refactor execution mode switch in keras layer.
1 parent cfc9eb3 commit 05dd652

File tree

18 files changed

+107
-74
lines changed

18 files changed

+107
-74
lines changed

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ public void restore_mode()
105105
context_switches.Pop();
106106
}
107107

108+
public void reset_context()
109+
{
110+
c_api.TFE_ContextClearCaches(_handle);
111+
}
112+
108113
public void Dispose()
109114
=> _handle.Dispose();
110115
}

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public string Name
3434
public ConcreteFunction(string name)
3535
{
3636
func_graph = new FuncGraph(name);
37+
func_graph.as_default();
3738
}
3839

3940
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
@@ -48,37 +49,36 @@ public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
4849
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
4950

5051
// IntPtr func_handle;
51-
using (var graph = new FuncGraph(func_name))
52-
{
53-
var input = tf.placeholder(dtype);
54-
var output = func(input);
55-
56-
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
57-
_handle = graph.ToGraph(opers,
58-
new[] { input },
59-
new[] { output },
60-
null);
61-
}
52+
using var graph = new FuncGraph(func_name);
53+
graph.as_default();
54+
var input = tf.placeholder(dtype);
55+
var output = func(input);
56+
57+
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
58+
_handle = graph.ToGraph(opers,
59+
new[] { input },
60+
new[] { output },
61+
null);
6262
}
6363

6464
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
6565
{
6666
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
6767

6868
// IntPtr func_handle;
69-
using (var graph = new FuncGraph(func_name))
70-
{
71-
var input = tf.placeholder(dtype);
72-
var output = func(input);
69+
using var graph = new FuncGraph(func_name);
70+
graph.as_default();
7371

74-
OutputStructure = output.structure;
72+
var input = tf.placeholder(dtype);
73+
var output = func(input);
7574

76-
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
77-
_handle = graph.ToGraph(opers,
78-
new[] { input },
79-
new[] { output.variant_tensor },
80-
null);
81-
}
75+
OutputStructure = output.structure;
76+
77+
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
78+
_handle = graph.ToGraph(opers,
79+
new[] { input },
80+
new[] { output.variant_tensor },
81+
null);
8282
}
8383

8484
public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
@@ -87,22 +87,22 @@ public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
8787
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
8888

8989
// IntPtr func_handle;
90-
using (var graph = new FuncGraph(func_name))
91-
{
92-
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
93-
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
94-
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
95-
var outputs = func(input1, (input2, input3));
96-
97-
Outputs = new[] { outputs.Item1, outputs.Item2 };
98-
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
99-
100-
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
101-
_handle = graph.ToGraph(opers,
102-
new[] { input1, input2, input3 },
103-
new[] { outputs.Item1, outputs.Item2 },
104-
null);
105-
}
90+
using var graph = new FuncGraph(func_name);
91+
graph.as_default();
92+
93+
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
94+
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
95+
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
96+
var outputs = func(input1, (input2, input3));
97+
98+
Outputs = new[] { outputs.Item1, outputs.Item2 };
99+
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
100+
101+
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
102+
_handle = graph.ToGraph(opers,
103+
new[] { input1, input2, input3 },
104+
new[] { outputs.Item1, outputs.Item2 },
105+
null);
106106
}
107107

108108
public void ToGraph(Tensors inputs, Tensors outputs)

src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public EagerDefinedFunction(string name, FuncGraph graph,
2626
var output_names = new string[0];
2727

2828
_func_graph = new FuncGraph(graph, name, attrs);
29+
_func_graph.as_default();
2930
_func_graph.ToGraph(operations, inputs, outputs, output_names);
3031
}
3132

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
8585

8686
var gradients_wrt_outputs = new List<Tensor>();
8787
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}");
88+
backwards_graph.as_default();
8889
foreach (var output in trainable_outputs)
8990
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
9091
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),

src/TensorFlowNET.Core/Graphs/AutoGraph.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
1313
// IntPtr func_handle;
1414
using (var graph = new FuncGraph(func_name))
1515
{
16+
graph.as_default();
1617
var input = tf.placeholder(tf.int32);
1718
var output = func(input);
1819

@@ -43,6 +44,7 @@ public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
4344
// IntPtr func_handle;
4445
using (var graph = new FuncGraph(func_name))
4546
{
47+
graph.as_default();
4648
var input1 = tf.placeholder(tf.int32);
4749
var input2 = tf.placeholder(tf.int32);
4850
var output = func(input1, input2);

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ public Tensor[] external_captures()
3030
public Tensor[] internal_captures()
3131
=> _captures.Select(x => x.Value.Item2).ToArray();
3232

33-
// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
34-
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray();
35-
3633
/// <summary>
3734
/// Construct a new FuncGraph.
3835
/// </summary>
@@ -43,8 +40,6 @@ public FuncGraph(string name) : base()
4340
outer_graph = outer_graph.OuterGraph;
4441
_graph_key = name;
4542
building_function = true;
46-
tf.Context.graph_mode();
47-
as_default();
4843
}
4944

5045
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
@@ -58,9 +53,6 @@ public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) :
5853
// Will to test if FuncGraph has memory leak
5954
// c_api.TF_DeleteGraph(_handle);
6055
_handle = handle;
61-
62-
tf.Context.graph_mode();
63-
as_default();
6456
}
6557

6658
public IntPtr ToGraph(Operation[] opers,
@@ -110,11 +102,21 @@ public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType
110102
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
111103
}
112104

113-
public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid)
105+
const int _EAGER_CONST_THRESHOLD = 128;
106+
public Tensor capture(Tensor tensor, string name = null, TensorShape shape = null)
114107
{
115108
if(tensor is EagerTensor)
116109
{
117-
throw new NotImplementedException("");
110+
if (name == null)
111+
name = ops.uid().ToString();
112+
113+
// Small EagerTensors are captured with Const ops
114+
if (dtypes.is_value_dtype(tensor.dtype)
115+
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
116+
return capture_eager_tensor(tensor, name);
117+
118+
// Large EagerTensors and resources are captured with Placeholder ops
119+
return _capture_helper(tensor, name, shape: shape);
118120
}
119121

120122
if(tensor.graph != this)
@@ -137,6 +139,9 @@ public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_
137139
return tensor;
138140
}
139141

142+
Tensor capture_eager_tensor(Tensor tensor, string name)
143+
=> throw new NotImplementedException("");
144+
140145
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
141146
{
142147
Tensor placeholder = null;
@@ -190,7 +195,8 @@ Tensor _create_substitute_placeholder(Tensor value,
190195
if (dtype == TF_DataType.DtInvalid)
191196
dtype = value.dtype;
192197

193-
var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name));
198+
var placeholder = tf_with(ops.control_dependencies(null), ctl
199+
=> array_ops.placeholder(dtype, shape: shape, name: name));
194200
// custom_gradient.copy_handle_data(value, placeholder)
195201
return placeholder;
196202
}
@@ -211,6 +217,13 @@ void SetAttrs()
211217
}
212218
}
213219

220+
public override Graph as_default()
221+
{
222+
tf.Context.graph_mode(isFunc: true);
223+
ops.set_default_graph(this);
224+
return this;
225+
}
226+
214227
protected override void DisposeManagedResources()
215228
{
216229
base.DisposeManagedResources();

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true,
148148
/// Returns a context manager that makes this `Graph` the default graph.
149149
/// </summary>
150150
/// <returns></returns>
151-
public Graph as_default()
151+
public virtual Graph as_default()
152152
{
153153
return ops.set_default_graph(this);
154154
}

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>2.2.0</TargetTensorFlow>
8-
<Version>0.31.2</Version>
8+
<Version>0.32.0</Version>
99
<LangVersion>8.0</LangVersion>
1010
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
1111
<Company>SciSharp STACK</Company>
@@ -15,7 +15,7 @@
1515
<RepositoryType>git</RepositoryType>
1616
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl>
1717
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
18-
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#, TF.NET</PackageTags>
18+
<PackageTags>TensorFlow, SciSharp, Machine Learning, TensorFlow.NET, TF.NET, AI</PackageTags>
1919
<Description>Google's TensorFlow full binding in .NET Standard.
2020
Building, training and infering deep learning models.
2121
https://tensorflownet.readthedocs.io</Description>

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,5 +293,12 @@ public static TF_DataType real_dtype(this TF_DataType self)
293293
else
294294
return self;
295295
}
296+
297+
public static bool is_value_dtype(this TF_DataType type)
298+
{
299+
return ((int)type >= 1 && (int)type <= 19)
300+
|| type == TF_DataType.TF_UINT32
301+
|| type == TF_DataType.TF_UINT64;
302+
}
296303
}
297304
}

src/TensorFlowNET.Core/ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,13 @@ public static Tensor convert_to_tensor(object value,
113113
{
114114
if (tf.executing_eagerly())
115115
return eager_tensor;
116-
/*else
116+
else
117117
{
118118
var graph = get_default_graph();
119119
if (!graph.building_function)
120120
throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
121121
return (graph as FuncGraph).capture(eager_tensor, name: name);
122-
}*/
122+
}
123123
}
124124

125125
Tensor ret = value switch

0 commit comments

Comments
 (0)