Skip to content

Commit 67a70bf

Browse files
committed
Remove _isCreatedInGraphMode in Tensor
1 parent c93c219 commit 67a70bf

File tree

6 files changed

+8
-19
lines changed

6 files changed

+8
-19
lines changed

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using Tensorflow.Eager;
56
using Tensorflow.Graphs;
7+
using Tensorflow.NumPy;
68
using static Tensorflow.Binding;
79
using static Tensorflow.tensorflow;
810

@@ -148,7 +150,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
148150
src_graph: _func_graph);
149151

150152
var captures_from_forward = backwards_graph.external_captures
151-
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
153+
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph)
152154
.ToArray();
153155
foreach(var capture in captures_from_forward)
154156
{

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ public partial class Tensor
3232

3333
public Tensor()
3434
{
35-
_isCreatedInGraphMode = !tf.executing_eagerly();
3635
}
3736

3837
/// <summary>
@@ -44,8 +43,6 @@ public unsafe Tensor(SafeTensorHandle handle, bool clone = false)
4443
_handle = handle;
4544
if (clone && handle != null)
4645
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer());
47-
48-
_isCreatedInGraphMode = !tf.executing_eagerly();
4946
}
5047

5148
/// <summary>
@@ -59,13 +56,11 @@ public unsafe Tensor(SafeTensorHandle handle, bool clone = false)
5956
public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype)
6057
{
6158
_handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer());
62-
_isCreatedInGraphMode = !tf.executing_eagerly();
6359
}
6460

6561
public unsafe Tensor(NDArray nd)
6662
{
6763
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
68-
_isCreatedInGraphMode = !tf.executing_eagerly();
6964
}
7065

7166
#region scala
@@ -107,13 +102,11 @@ public Tensor(Operation op, int value_index, TF_DataType dtype)
107102
_value_index = value_index;
108103
_override_dtype = dtype;
109104
_id = ops.uid();
110-
_isCreatedInGraphMode = !tf.executing_eagerly();
111105
}
112106

113107
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
114108
{
115109
_handle = TF_NewTensor(shape, dtype, null);
116-
_isCreatedInGraphMode = !tf.executing_eagerly();
117110
}
118111

119112
protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -122,13 +115,10 @@ protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
122115
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
123116
else
124117
_handle = TF_NewTensor(bytes, shape, dtype);
125-
_isCreatedInGraphMode = !tf.executing_eagerly();
126118
}
127119

128120
protected unsafe void InitTensor(Array array, Shape? shape = null)
129121
{
130-
_isCreatedInGraphMode = !tf.executing_eagerly();
131-
132122
shape = shape ?? array.GetShape();
133123
var dtype = array.GetDataType();
134124

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ public partial class Tensor : DisposableObject,
9494
/// </summary>
9595
public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle;
9696

97-
protected bool _isCreatedInGraphMode;
98-
99-
public bool IsCreatedInGraphMode => _isCreatedInGraphMode;
100-
10197
/// <summary>
10298
/// Returns the shape of a tensor.
10399
/// </summary>

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ public class Tensors : IEnumerable<Tensor>, IDisposable
2121
public Shape shape => items.First().shape;
2222
public int rank => items.First().rank;
2323
public Graph graph => items.First().graph;
24-
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode;
2524
public bool IsList { get; set; }
2625
public int Length => items.Count();
2726

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public void __init__(bool trainable = true,
6868
// when this object is garbage collected the deleter will be too. This
6969
// means ResourceVariables can be part of reference cycles without those
7070
// cycles being uncollectable.
71-
if (!handle.IsCreatedInGraphMode)
71+
if (handle is EagerTensor)
7272
{
7373
_handle = handle.EagerTensorHandle.DangerousGetHandle();
7474
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ limitations under the License.
1818
using System.Collections.Generic;
1919
using System.Linq;
2020
using System.Threading;
21+
using Tensorflow.Eager;
2122
using Tensorflow.Keras.ArgsDefinition;
2223
using Tensorflow.Keras.Saving;
2324
using Tensorflow.Keras.Utils;
25+
using Tensorflow.NumPy;
2426
using Tensorflow.Train;
2527
using static Tensorflow.Binding;
2628

@@ -118,7 +120,7 @@ public Layer(LayerArgs args)
118120
bool _in_functional_construction_mode(Tensors inputs)
119121
{
120122
return tf.Context.executing_eagerly()
121-
&& inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count();
123+
&& inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count();
122124
}
123125

124126
public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
@@ -180,7 +182,7 @@ protected void MaybeBuild(Tensors inputs)
180182
tf.init_scope();
181183

182184
bool need_restore_mode = false;
183-
if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function())
185+
if (inputs.Any(x => x is EagerTensor) || tf.Context.is_build_function())
184186
{
185187
need_restore_mode = true;
186188
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());

0 commit comments

Comments
 (0)