Skip to content

Commit 17a4fe0

Browse files
lsylusiyaoEsther2013
authored and committed
Add Custom Keras Layer Test
This test is written based on Dueling DQN's network structure.
1 parent 0d5e5e0 commit 17a4fe0

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,72 @@ public void Functional()
3535
var model = keras.Model(inputs, outputs, name: "mnist_model");
3636
model.summary();
3737
}
/// <summary>
/// Custom-layer test modeled on the Dueling DQN network structure:
/// the advantage stream is zero-centered by subtracting its mean
/// (adv - mean(adv)) and then added to the value stream.
/// </summary>
/// <remarks>
/// Marked <c>Ignore</c>: the raw tensor arithmetic on <c>adv_out</c> below is
/// the known problem spot (see inline note) — presumably unsupported in the
/// functional-model graph at the time of writing; TODO confirm before re-enabling.
/// </remarks>
[TestMethod, Ignore]
public void FunctionalTest()
{
    var layers = keras.layers;
    var inputs = layers.Input(shape: 24);
    var x = layers.Dense(128, activation: "relu").Apply(inputs);
    var value = layers.Dense(24).Apply(x);
    var adv = layers.Dense(1).Apply(x);

    // Zero-center the advantage stream. This is plain tensor arithmetic rather
    // than a Layer (the Python version wraps it in Lambda + Subtract), and it
    // is the reported failure point of this test. // Here's problem.
    var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true);
    var outputs = layers.Add().Apply(new Tensors(adv_out, value));
    var model = keras.Model(inputs, outputs);
    model.summary();
    model.compile(optimizer: keras.optimizers.RMSprop(0.001f),
        loss: keras.losses.MeanSquaredError(),
        metrics: new[] { "acc" });

    // Here we consider adv_out to be one layer, which is a little different
    // from the Python version (Lambda + Subtract would add an extra layer).
    // MSTest's Assert.AreEqual signature is (expected, actual) — expected first,
    // otherwise failure messages report the values swapped.
    Assert.AreEqual(6, model.Layers.Count);

    // py code:
    //from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda
    //from tensorflow.keras.models import Model
    //from tensorflow.keras.optimizers import RMSprop
    //import tensorflow.keras.backend as K

    //inputs = Input(24)
    //x = Dense(128, activation = "relu")(inputs)
    //value = Dense(24)(x)
    //adv = Dense(1)(x)
    //meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv)
    //adv = Subtract()([adv, meam])
    //outputs = Add()([value, adv])
    //model = Model(inputs, outputs)
    //model.compile(loss = "mse", optimizer = RMSprop(1e-3))
    //model.summary()

    //py output:
    //Model: "functional_3"
    //__________________________________________________________________________________________________
    //Layer(type)                     Output Shape         Param #     Connected to
    //==================================================================================================
    //input_2 (InputLayer)            [(None, 24)]         0
    //__________________________________________________________________________________________________
    //dense_3 (Dense)                 (None, 128)          3200        input_2[0][0]
    //__________________________________________________________________________________________________
    //dense_5 (Dense)                 (None, 1)            129         dense_3[0][0]
    //__________________________________________________________________________________________________
    //lambda_1 (Lambda)               (None, 1)            0           dense_5[0][0]
    //__________________________________________________________________________________________________
    //dense_4 (Dense)                 (None, 24)           3096        dense_3[0][0]
    //__________________________________________________________________________________________________
    //subtract_1 (Subtract)           (None, 1)            0           dense_5[0][0]
    //                                                                 lambda_1[0][0]
    //__________________________________________________________________________________________________
    //add_1 (Add)                     (None, 24)           0           dense_4[0][0]
    //                                                                 subtract_1[0][0]
    //==================================================================================================
    //Total params: 6,425
    //Trainable params: 6,425
    //Non-trainable params: 0
    //__________________________________________________________________________________________________
}
38104

39105
/// <summary>
40106
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding

0 commit comments

Comments
 (0)