Skip to content

Commit c906f46

Browse files
learning rate test
1 parent f7b8dba commit c906f46

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.VisualStudio.TestPlatform.Utilities;
22
using Microsoft.VisualStudio.TestTools.UnitTesting;
33
using System;
4+
using System.Diagnostics;
45
using System.Linq;
56
using Tensorflow.NumPy;
67
using TensorFlowNET.UnitTest;
@@ -69,5 +70,53 @@ public void TestBasic()
6970
TestBasic<double>();
7071
}
7172

73+
private void TestTensorLearningRate<T>() where T : struct
74+
{
75+
var dtype = GetTypeForNumericType<T>();
76+
77+
// train.GradientDescentOptimizer is V1 only API.
78+
tf.Graph().as_default();
79+
using (var sess = self.cached_session())
80+
{
81+
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
82+
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
83+
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
84+
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
85+
var lrate = constant_op.constant(3.0);
86+
var grads_and_vars = new[] {
87+
Tuple.Create(grads0, var0 as IVariableV1),
88+
Tuple.Create(grads1, var1 as IVariableV1)
89+
};
90+
var sgd_op = tf.train.GradientDescentOptimizer(lrate)
91+
.apply_gradients(grads_and_vars);
92+
93+
var global_variables = tf.global_variables_initializer();
94+
sess.run(global_variables);
95+
96+
var initialVar0 = sess.run(var0);
97+
var initialVar1 = sess.run(var1);
98+
// Fetch params to validate initial values
99+
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
100+
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
101+
// Run 1 step of sgd
102+
sgd_op.run();
103+
// Validate updated params
104+
self.assertAllCloseAccordingToType(
105+
new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
106+
self.evaluate<T[]>(var0));
107+
self.assertAllCloseAccordingToType(
108+
new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
109+
self.evaluate<T[]>(var1));
110+
// TODO: self.assertEqual(0, len(optimizer.variables()));
111+
}
112+
}
113+
114+
[TestMethod]
115+
public void TestTensorLearningRate()
116+
{
117+
//TODO: add np.half
118+
TestTensorLearningRate<float>();
119+
TestTensorLearningRate<double>();
120+
}
72121
}
73122
}

0 commit comments

Comments
 (0)