Skip to content

Commit cfc9eb3

Browse files
committed
Remove internal_convert_to_tensor.
1 parent 734fe29 commit cfc9eb3

File tree

10 files changed

+63
-61
lines changed

10 files changed

+63
-61
lines changed

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace Tensorflow
1919
public partial class tensorflow
2020
{
2121
public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
22-
=> ops.convert_to_tensor(value, dtype, name, preferred_dtype);
22+
=> ops.convert_to_tensor(value, dtype, name, preferred_dtype: preferred_dtype);
2323

2424
public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null,
2525
int begin_mask = 0,

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public Tensor AsPlaceholder(string name = null)
6969
return placeholder;
7070
}
7171

72-
public Tensor AsContatnt(string name = null)
72+
public Tensor AsConstant(string name = null)
7373
{
7474
Tensor constant = null;
7575
tf_with(ops.control_dependencies(null), delegate

src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
2929
{
3030
indices = ops.convert_to_tensor(
3131
indices_, name: "indices", dtype: dtypes.int64);
32-
values = ops.internal_convert_to_tensor(values_, name: "values");
32+
values = ops.convert_to_tensor(values_, name: "values");
3333
dense_shape = ops.convert_to_tensor(
3434
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
3535
});

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ namespace Tensorflow.Graphs
1313
/// </summary>
1414
public class FuncGraph : Graph
1515
{
16-
Graph outer_graph;
17-
public Graph OuterGraph => outer_graph;
18-
1916
// _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
2017
IntPtr func_handle;
2118
public string FuncName => _graph_key;
@@ -42,16 +39,21 @@ public Tensor[] internal_captures()
4239
public FuncGraph(string name) : base()
4340
{
4441
outer_graph = ops.get_default_graph();
42+
while (outer_graph.building_function)
43+
outer_graph = outer_graph.OuterGraph;
4544
_graph_key = name;
46-
45+
building_function = true;
4746
tf.Context.graph_mode();
4847
as_default();
4948
}
5049

5150
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
5251
{
5352
outer_graph = ops.get_default_graph();
53+
while (outer_graph.building_function)
54+
outer_graph = outer_graph.OuterGraph;
5455
_graph_key = name;
56+
building_function = true;
5557
Attrs = attrs;
5658
// Will to test if FuncGraph has memory leak
5759
// c_api.TF_DeleteGraph(_handle);
@@ -108,7 +110,7 @@ public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType
108110
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
109111
}
110112

111-
Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid)
113+
public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid)
112114
{
113115
if(tensor is EagerTensor)
114116
{

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ public int seed
118118
}
119119
}
120120

121+
protected Graph outer_graph;
122+
public Graph OuterGraph => outer_graph;
123+
121124
public Graph()
122125
{
123126
_handle = c_api.TF_NewGraph();

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public Operation _apply_op_helper(string op_type_name, string name = null, Dicti
148148
else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr))
149149
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
150150

151-
var value = ops.internal_convert_to_tensor(values,
151+
var value = ops.convert_to_tensor(values,
152152
name: input_name,
153153
dtype: dtype.as_tf_dtype(),
154154
as_ref: input_arg.IsRef,

src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public static IEnumerable<MySaveableObject> saveable_objects_for_op(Tensor op, s
6666
else
6767
{
6868
ops.init_scope();
69-
var variable = ops.internal_convert_to_tensor(op, as_ref: true);
69+
var variable = ops.convert_to_tensor(op, as_ref: true);
7070
if (variable.dtype.is_ref_dtype())
7171
yield return new ReferenceVariableSaveable(variable, "", name);
7272
else
@@ -103,7 +103,7 @@ public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list,
103103
if (!var.dtype.is_ref_dtype())
104104
tensor = var.GraphElement;
105105
else
106-
tensor = ops.internal_convert_to_tensor(var, as_ref: true);
106+
tensor = ops.convert_to_tensor(var, as_ref: true);
107107
}
108108

109109
if (tensor.op.type == "ReadVariableOp")

src/TensorFlowNET.Core/ops.cs

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
using System.Threading;
2525
using Tensorflow.Contexts;
2626
using Tensorflow.Eager;
27+
using Tensorflow.Graphs;
2728
using Tensorflow.Util;
2829
using static Tensorflow.Binding;
2930

@@ -101,14 +102,44 @@ public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph =
101102
public static Tensor convert_to_tensor(object value,
102103
TF_DataType dtype = TF_DataType.DtInvalid,
103104
string name = null,
105+
bool as_ref = false,
104106
TF_DataType preferred_dtype = TF_DataType.DtInvalid,
105107
Context ctx = null)
106108
{
107-
return internal_convert_to_tensor(value,
108-
dtype: dtype,
109-
name: name,
110-
preferred_dtype: preferred_dtype,
111-
as_ref: false);
109+
if (dtype == TF_DataType.DtInvalid)
110+
dtype = preferred_dtype;
111+
112+
if (value is EagerTensor eager_tensor)
113+
{
114+
if (tf.executing_eagerly())
115+
return eager_tensor;
116+
/*else
117+
{
118+
var graph = get_default_graph();
119+
if (!graph.building_function)
120+
throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
121+
return (graph as FuncGraph).capture(eager_tensor, name: name);
122+
}*/
123+
}
124+
125+
Tensor ret = value switch
126+
{
127+
NDArray nd => constant_op.constant(nd, dtype: dtype, name: name),
128+
EagerTensor tensor => tensor.dtype == TF_DataType.TF_RESOURCE
129+
? tensor.AsPlaceholder(name: name)
130+
: tensor.AsConstant(name: name),
131+
Tensor tensor => tensor,
132+
Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
133+
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
134+
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
135+
TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
136+
int[] dims => constant_op.constant(dims, dtype: dtype, name: name),
137+
string str => constant_op.constant(str, dtype: tf.@string, name: name),
138+
object[] objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name),
139+
_ => constant_op.constant(value, dtype: dtype, name: name)
140+
};
141+
142+
return ret;
112143
}
113144

114145

@@ -118,9 +149,7 @@ public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dt
118149
}
119150

120151
public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
121-
{
122-
return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref);
123-
}
152+
=> convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref);
124153

125154
/// <summary>
126155
/// Wrapper for `Graph.control_dependencies()` using the default graph.
@@ -460,52 +489,12 @@ public static Tensor[] internal_convert_n_to_tensor(object values, TF_DataType d
460489
foreach ((int i, object value) in enumerate(values as object[]))
461490
{
462491
string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
463-
ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
492+
ret.Add(convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
464493
}
465494

466495
return ret.ToArray();
467496
}
468497

469-
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid,
470-
string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
471-
bool as_ref = false,
472-
string scope = null)
473-
{
474-
if (dtype == TF_DataType.DtInvalid)
475-
dtype = preferred_dtype;
476-
477-
switch (value)
478-
{
479-
case NDArray nd:
480-
return constant_op.constant(nd, dtype: dtype, name: name);
481-
case EagerTensor tensor:
482-
if (tf.executing_eagerly())
483-
return tensor;
484-
else
485-
return tensor.dtype == TF_DataType.TF_RESOURCE
486-
? tensor.AsPlaceholder(name: name)
487-
: tensor.AsContatnt(name: name);
488-
case Tensor tensor:
489-
return tensor;
490-
case Tensor[] tensors:
491-
return array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name);
492-
case RefVariable varVal:
493-
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
494-
case ResourceVariable varVal:
495-
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
496-
case TensorShape ts:
497-
return constant_op.constant(ts.dims, dtype: dtype, name: name);
498-
case string str:
499-
return constant_op.constant(value, dtype: tf.@string, name: name);
500-
case int[] dims:
501-
return constant_op.constant(dims, dtype: dtype, name: name);
502-
case object[] objects:
503-
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
504-
default:
505-
return constant_op.constant(value, dtype: dtype, name: name);
506-
}
507-
}
508-
509498
public static string strip_name_scope(string name, string export_scope = "")
510499
{
511500
if (!string.IsNullOrEmpty(export_scope))

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using System.Collections.Generic;
20+
using Tensorflow.Graphs;
2021
using static Tensorflow.Binding;
2122

2223
namespace Tensorflow.Keras
@@ -78,6 +79,12 @@ public Tensor placeholder(TensorShape shape = null,
7879

7980
public Graph get_graph()
8081
{
82+
if (tf.Context.executing_eagerly())
83+
{
84+
if (_GRAPH == null)
85+
_GRAPH = new FuncGraph("keras_graph");
86+
return _GRAPH;
87+
}
8188
return ops.get_default_graph();
8289
}
8390

src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using Tensorflow.Keras.Utils;
33
using static Tensorflow.Binding;
4+
using static Tensorflow.KerasApi;
45

56
namespace Tensorflow.Keras.Engine
67
{
@@ -22,7 +23,7 @@ Tensors FunctionalConstructionCall(Tensors inputs)
2223
Tensors outputs = null;
2324
using var ctxManager = CallContext.enter();
2425

25-
// using var graph = tf.keras.backend.get_graph().as_default();
26+
// using var graph = keras.backend.get_graph();
2627

2728
if (!inputs.IsEagerTensor)
2829
tf.Context.graph_mode(isFunc: true);

0 commit comments

Comments
 (0)