Skip to content

Commit 58d2dae

Browse files
committed
Fix graph instance in InputLayer.
1 parent 05dd652 commit 58d2dae

File tree

4 files changed

+11
-30
lines changed

4 files changed

+11
-30
lines changed

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ public class FuncGraph : Graph
1717
IntPtr func_handle;
1818
public string FuncName => _graph_key;
1919

20-
public Tensors Inputs { get; set; }
21-
public Tensors Outputs { get; set; }
20+
public Tensors Inputs { get; set; } = new Tensors();
21+
public Tensors Outputs { get; set; } = new Tensors();
2222
public Dictionary<string, string> Attrs { get; set; }
2323

2424
public Dictionary<long, (Tensor, Tensor)> _captures
@@ -175,14 +175,7 @@ Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
175175
void add_capture(Tensor tensor, Tensor placeholder)
176176
{
177177
_captures.Add(tensor.Id, (tensor, placeholder));
178-
if (Inputs == null)
179-
Inputs = new Tensors(placeholder);
180-
else
181-
{
182-
var inputs = Inputs.ToList();
183-
inputs.Add(placeholder);
184-
Inputs = new Tensors(inputs.ToArray());
185-
}
178+
Inputs.Add(placeholder);
186179
}
187180

188181
Tensor _create_substitute_placeholder(Tensor value,

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ public class BaseSession : DisposableObject
3939
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
4040
{
4141
_graph = g ?? ops.get_default_graph();
42-
_graph.as_default();
42+
if (!_graph.building_function)
43+
_graph.as_default();
4344
_target = Encoding.UTF8.GetBytes(target);
4445

4546
using (var opts = new SessionOptions(target, config))

src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,6 @@ public InputLayer(InputLayerArgs args) :
5858
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype;
5959
}
6060

61-
// In graph mode, create a graph placeholder to call the layer on.
62-
tf.Context.graph_mode();
63-
6461
if (args.InputTensor == null)
6562
{
6663
if (args.InputShape != null)
@@ -74,15 +71,18 @@ public InputLayer(InputLayerArgs args) :
7471
args.BatchInputShape = null;
7572
}
7673

74+
var graph = keras.backend.get_graph();
75+
graph.as_default();
76+
7777
args.InputTensor = keras.backend.placeholder(
7878
shape: BatchInputShape,
7979
dtype: DType,
8080
name: Name,
8181
sparse: args.Sparse,
8282
ragged: args.Ragged);
8383

84-
8584
isPlaceholder = true;
85+
tf.Context.restore_mode();
8686
}
8787

8888
// Create an input node to add to self.outbound_node
@@ -97,8 +97,6 @@ public InputLayer(InputLayerArgs args) :
9797
typeSpec = new TensorSpec(args.InputTensor.TensorShape,
9898
dtype: args.InputTensor.dtype,
9999
name: Name);
100-
101-
tf.Context.restore_mode();
102100
}
103101

104102
public static InputLayer from_config(LayerArgs args)

src/TensorFlowNET.Keras/Utils/base_layer_utils.cs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,23 +151,12 @@ public static void CreateKerasHistoryHelper(Tensors tensors, List<Operation> pro
151151

152152
// recursively
153153
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
154-
Layer op_layer = null;
155-
/*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
154+
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
156155
{
157156
NodeDef = op.node_def,
158157
Constants = constants,
159158
Name = op.name
160-
});*/
161-
op_layer = op.type switch
162-
{
163-
// "AddV2" => keras.layers.Add(),
164-
_ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs
165-
{
166-
NodeDef = op.node_def,
167-
Constants = constants,
168-
Name = op.name
169-
})
170-
};
159+
});
171160
created_layers.Add(op_layer);
172161
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
173162
processed_ops.Add(op);

0 commit comments

Comments
 (0)