Skip to content

Commit 9f0ffa4

Browse files
committed
Implemented unittests for Concatenate layers and calls
The loading and saving of a simple model with a Concatenate layer is tested to check if the model is the same after reloading. Implemented missing axis parameter for np.stack (added some handy tuple calls too like the np.concatenate example).
1 parent 93a242c commit 9f0ffa4

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ public static NDArray concatenate((NDArray, NDArray) tuple, int axis = 0)
3030
[AutoNumPy]
3131
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));
3232

33+
[AutoNumPy]
34+
public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis));
35+
36+
[AutoNumPy]
37+
public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis));
38+
39+
[AutoNumPy]
40+
public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis));
41+
3342
[AutoNumPy]
3443
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
3544
}
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System.Collections.Generic;
23
using Tensorflow.NumPy;
34
using static Tensorflow.KerasApi;
45

@@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers
89
public class LayersMergingTest : EagerModeTestBase
910
{
1011
[TestMethod]
11-
public void Concatenate()
12+
[DataRow(1, 4, 1, 5)]
13+
[DataRow(2, 2, 2, 5)]
14+
[DataRow(3, 2, 1, 10)]
15+
public void Concatenate(int axis, int shapeA, int shapeB, int shapeC)
1216
{
13-
var x = np.arange(20).reshape((2, 2, 5));
14-
var y = np.arange(20, 30).reshape((2, 1, 5));
15-
var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
16-
Assert.AreEqual((2, 3, 5), z.shape);
17+
var x = np.arange(10).reshape((1, 2, 1, 5));
18+
var y = np.arange(10, 20).reshape((1, 2, 1, 5));
19+
var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y));
20+
Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape);
1721
}
22+
1823
}
1924
}

test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
using Microsoft.VisualStudio.TestPlatform.Utilities;
22
using Microsoft.VisualStudio.TestTools.UnitTesting;
3+
using Newtonsoft.Json.Linq;
34
using System.Linq;
5+
using System.Xml.Linq;
46
using Tensorflow.Keras.Engine;
57
using Tensorflow.Keras.Optimizers;
68
using Tensorflow.Keras.UnitTest.Helpers;
79
using Tensorflow.NumPy;
10+
using static HDF.PInvoke.H5Z;
811
using static Tensorflow.Binding;
912
using static Tensorflow.KerasApi;
1013

@@ -124,4 +127,44 @@ public void TestModelBeforeTF2_5()
124127
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
125128
model.summary();
126129
}
130+
131+
132+
133+
[TestMethod]
134+
public void CreateConcatenateModelSaveAndLoad()
135+
{
136+
// a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded.
137+
var input_layer = tf.keras.layers.Input((8, 8, 5));
138+
139+
var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer);
140+
conv1.Name = "conv1";
141+
142+
var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer);
143+
conv2.Name = "conv2";
144+
145+
var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2));
146+
concat1.Name = "concat1";
147+
148+
var model = tf.keras.Model(input_layer, concat1);
149+
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
150+
151+
model.save(@"Assets/concat_axis3_model");
152+
153+
154+
var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT);
155+
156+
var tensors1 = model.predict(tensorInput);
157+
158+
Assert.AreEqual((1, 8, 8, 4), tensors1.shape);
159+
160+
model = null;
161+
keras.backend.clear_session();
162+
163+
var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model");
164+
165+
var tensors2 = model2.predict(tensorInput);
166+
167+
Assert.AreEqual(tensors1.shape, tensors2.shape);
168+
}
169+
127170
}

0 commit comments

Comments
 (0)