Skip to content

Commit 93a242c

Browse files
committed
Implemented support for loading Concatenate layers
model.load_model now supports loading of concatenate layers. python tensorflow exports concatenate layers in an extra nested array in the manifest so added a check for that in generic_utils.cs. Concatenate was missing the build=true, this fix prevents the layer being build multiple times. Concatenate has 2 or more input nodes so List<NodeConfig> was required instead of just NodeConfig in Functional.FromConfig.cs. Added missing axis JsonProperty attribute for MergeArgs (used by Concatenate)
1 parent 090dc1e commit 93a242c

File tree

4 files changed

+35
-15
lines changed

4 files changed

+35
-15
lines changed
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
using System;
1+
using Newtonsoft.Json;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45

56
namespace Tensorflow.Keras.ArgsDefinition
67
{
78
// TODO: complete the implementation
8-
public class MergeArgs : LayerArgs
9+
public class MergeArgs : AutoSerializeLayerArgs
910
{
1011
public Tensors Inputs { get; set; }
12+
[JsonProperty("axis")]
1113
public int Axis { get; set; }
1214
}
1315
}

src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_co
3030
created_layers = created_layers ?? new Dictionary<string, ILayer>();
3131
var node_index_map = new Dictionary<(string, int), int>();
3232
var node_count_by_layer = new Dictionary<ILayer, int>();
33-
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
33+
var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>();
3434
// First, we create all layers and enqueue nodes to be processed
3535
foreach (var layer_data in config.Layers)
3636
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
@@ -79,7 +79,7 @@ public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_co
7979

8080
static void process_layer(Dictionary<string, ILayer> created_layers,
8181
LayerConfig layer_data,
82-
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
82+
Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes,
8383
Dictionary<ILayer, int> node_count_by_layer)
8484
{
8585
ILayer layer = null;
@@ -92,32 +92,38 @@ static void process_layer(Dictionary<string, ILayer> created_layers,
9292

9393
created_layers[layer_name] = layer;
9494
}
95-
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
95+
node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0);
9696

9797
var inbound_nodes_data = layer_data.InboundNodes;
9898
foreach (var node_data in inbound_nodes_data)
9999
{
100100
if (!unprocessed_nodes.ContainsKey(layer))
101-
unprocessed_nodes[layer] = node_data;
101+
unprocessed_nodes[layer] = new List<NodeConfig>() { node_data };
102102
else
103-
unprocessed_nodes.Add(layer, node_data);
103+
unprocessed_nodes[layer].Add(node_data);
104104
}
105105
}
106106

107107
static void process_node(ILayer layer,
108-
NodeConfig node_data,
108+
List<NodeConfig> nodes_data,
109109
Dictionary<string, ILayer> created_layers,
110110
Dictionary<ILayer, int> node_count_by_layer,
111111
Dictionary<(string, int), int> node_index_map)
112112
{
113+
113114
var input_tensors = new List<Tensor>();
114-
var inbound_layer_name = node_data.Name;
115-
var inbound_node_index = node_data.NodeIndex;
116-
var inbound_tensor_index = node_data.TensorIndex;
117115

118-
var inbound_layer = created_layers[inbound_layer_name];
119-
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
120-
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
116+
for (int i = 0; i < nodes_data.Count; i++)
117+
{
118+
var node_data = nodes_data[i];
119+
var inbound_layer_name = node_data.Name;
120+
var inbound_node_index = node_data.NodeIndex;
121+
var inbound_tensor_index = node_data.TensorIndex;
122+
123+
var inbound_layer = created_layers[inbound_layer_name];
124+
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
125+
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
126+
}
121127

122128
var output_tensors = layer.Apply(input_tensors);
123129

src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public override void build(KerasShapesWrapper input_shape)
3939
shape_set.Add(shape);
4040
}*/
4141
_buildInputShape = input_shape;
42+
built = true;
4243
}
4344

4445
protected override Tensors _merge_function(Tensors inputs)

src/TensorFlowNET.Keras/Utils/generic_utils.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,23 @@ public static FunctionalConfig deserialize_model_config(JToken json)
112112
foreach (var token in layersToken)
113113
{
114114
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);
115+
116+
List<NodeConfig> nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array
117+
if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0)
118+
{
119+
nodeConfig = token["inbound_nodes"].ToObject<List<List<NodeConfig>>>().FirstOrDefault() ?? new List<NodeConfig>();
120+
}
121+
else
122+
{
123+
nodeConfig = token["inbound_nodes"].ToObject<List<NodeConfig>>();
124+
}
125+
115126
config.Layers.Add(new LayerConfig()
116127
{
117128
Config = args,
118129
Name = token["name"].ToObject<string>(),
119130
ClassName = token["class_name"].ToObject<string>(),
120-
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
131+
InboundNodes = nodeConfig,
121132
});
122133
}
123134
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();

0 commit comments

Comments
 (0)