|
1 | 1 | using Microsoft.VisualStudio.TestPlatform.Utilities; |
2 | 2 | using Microsoft.VisualStudio.TestTools.UnitTesting; |
| 3 | +using Newtonsoft.Json.Linq; |
3 | 4 | using System.Linq; |
| 5 | +using System.Xml.Linq; |
4 | 6 | using Tensorflow.Keras.Engine; |
5 | 7 | using Tensorflow.Keras.Optimizers; |
6 | 8 | using Tensorflow.Keras.UnitTest.Helpers; |
7 | 9 | using Tensorflow.NumPy; |
| 10 | +using static HDF.PInvoke.H5Z; |
8 | 11 | using static Tensorflow.Binding; |
9 | 12 | using static Tensorflow.KerasApi; |
10 | 13 |
|
@@ -124,4 +127,44 @@ public void TestModelBeforeTF2_5() |
124 | 127 | var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; |
125 | 128 | model.summary(); |
126 | 129 | } |
| 130 | + |
| 131 | + |
| 132 | + |
| 133 | + [TestMethod] |
| 134 | + public void CreateConcatenateModelSaveAndLoad() |
| 135 | + { |
| 136 | + // a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded. |
| 137 | + var input_layer = tf.keras.layers.Input((8, 8, 5)); |
| 138 | + |
| 139 | + var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer); |
| 140 | + conv1.Name = "conv1"; |
| 141 | + |
| 142 | + var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer); |
| 143 | + conv2.Name = "conv2"; |
| 144 | + |
| 145 | + var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2)); |
| 146 | + concat1.Name = "concat1"; |
| 147 | + |
| 148 | + var model = tf.keras.Model(input_layer, concat1); |
| 149 | + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); |
| 150 | + |
| 151 | + model.save(@"Assets/concat_axis3_model"); |
| 152 | + |
| 153 | + |
| 154 | + var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT); |
| 155 | + |
| 156 | + var tensors1 = model.predict(tensorInput); |
| 157 | + |
| 158 | + Assert.AreEqual((1, 8, 8, 4), tensors1.shape); |
| 159 | + |
| 160 | + model = null; |
| 161 | + keras.backend.clear_session(); |
| 162 | + |
| 163 | + var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model"); |
| 164 | + |
| 165 | + var tensors2 = model2.predict(tensorInput); |
| 166 | + |
| 167 | + Assert.AreEqual(tensors1.shape, tensors2.shape); |
| 168 | + } |
| 169 | + |
127 | 170 | } |
0 commit comments