|
1 | 1 | using Microsoft.VisualStudio.TestPlatform.Utilities; |
2 | 2 | using Microsoft.VisualStudio.TestTools.UnitTesting; |
3 | 3 | using System; |
| 4 | +using System.Diagnostics; |
4 | 5 | using System.Linq; |
5 | 6 | using Tensorflow.NumPy; |
6 | 7 | using TensorFlowNET.UnitTest; |
@@ -69,5 +70,53 @@ public void TestBasic() |
69 | 70 | TestBasic<double>(); |
70 | 71 | } |
71 | 72 |
|
| 73 | + private void TestTensorLearningRate<T>() where T : struct |
| 74 | + { |
| 75 | + var dtype = GetTypeForNumericType<T>(); |
| 76 | + |
| 77 | + // train.GradientDescentOptimizer is V1 only API. |
| 78 | + tf.Graph().as_default(); |
| 79 | + using (var sess = self.cached_session()) |
| 80 | + { |
| 81 | + var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); |
| 82 | + var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); |
| 83 | + var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); |
| 84 | + var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); |
| 85 | + var lrate = constant_op.constant(3.0); |
| 86 | + var grads_and_vars = new[] { |
| 87 | + Tuple.Create(grads0, var0 as IVariableV1), |
| 88 | + Tuple.Create(grads1, var1 as IVariableV1) |
| 89 | + }; |
| 90 | + var sgd_op = tf.train.GradientDescentOptimizer(lrate) |
| 91 | + .apply_gradients(grads_and_vars); |
| 92 | + |
| 93 | + var global_variables = tf.global_variables_initializer(); |
| 94 | + sess.run(global_variables); |
| 95 | + |
| 96 | + var initialVar0 = sess.run(var0); |
| 97 | + var initialVar1 = sess.run(var1); |
| 98 | + // Fetch params to validate initial values |
| 99 | + self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0)); |
| 100 | + self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1)); |
| 101 | + // Run 1 step of sgd |
| 102 | + sgd_op.run(); |
| 103 | + // Validate updated params |
| 104 | + self.assertAllCloseAccordingToType( |
| 105 | + new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, |
| 106 | + self.evaluate<T[]>(var0)); |
| 107 | + self.assertAllCloseAccordingToType( |
| 108 | + new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, |
| 109 | + self.evaluate<T[]>(var1)); |
| 110 | + // TODO: self.assertEqual(0, len(optimizer.variables())); |
| 111 | + } |
| 112 | + } |
| 113 | + |
| 114 | + [TestMethod] |
| 115 | + public void TestTensorLearningRate() |
| 116 | + { |
| 117 | + //TODO: add np.half |
| 118 | + TestTensorLearningRate<float>(); |
| 119 | + TestTensorLearningRate<double>(); |
| 120 | + } |
72 | 121 | } |
73 | 122 | } |
0 commit comments