
Commit 2121079

gradient descent tests
1 parent 3811e4e commit 2121079


4 files changed: +250 additions, -5 deletions


src/TensorFlowNET.Core/Variables/variables.py.cs

Lines changed: 6 additions & 1 deletion
@@ -72,7 +72,9 @@ public static List<IVariableV1> global_variables(string scope = null)
         public static Operation variables_initializer(IVariableV1[] var_list, string name = "init")
         {
             if (var_list.Length > 0)
+            {
                 return control_flow_ops.group(var_list.Select(x => x.Initializer).ToArray(), name);
+            }
             else
                 return gen_control_flow_ops.no_op(name: name);
         }
@@ -155,7 +157,10 @@ public static Operation _safe_initial_value_from_op(string name, Operation op, D

         public static Tensor global_variables_initializer()
         {
-            throw new NotImplementedException();
+            // if context.executing_eagerly():
+            //     return control_flow_ops.no_op(name = "global_variables_initializer")
+            var group = variables_initializer(global_variables().ToArray());
+            return group;
         }
     }
 }
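With this change, variables.global_variables_initializer() returns a grouped initializer op instead of throwing. A minimal graph-mode usage sketch (illustrative only; the variable, the session handling, and the exact Session.run overloads are assumptions, not part of this commit):

using Tensorflow;
using static Tensorflow.Binding;

// Build a tiny graph with one variable, then run the grouped initializer once
// before reading any variable value.
var graph = tf.Graph().as_default();
var w = tf.Variable(new[] { 1.0f, 2.0f }, name: "w");        // hypothetical variable
var init = variables.global_variables_initializer();         // group of every variable's Initializer op

using (var sess = tf.Session(graph))
{
    sess.run(init);          // initialize all global variables first
    var value = sess.run(w); // now safe to evaluate w
}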

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

Lines changed: 0 additions & 2 deletions
@@ -776,8 +776,6 @@ public void testUnconnectedGradientsNoneUnconnectedGradients()
         [TestMethod]
         public void testUnconnectedGradientsZerosUnconnectedGradients()
         {
-
-
             //def testUnconnectedGradientsZerosUnconnectedGradients(self):
             //    with ops.Graph().as_default():
             //        x = constant(1.0, shape=[2, 2])

test/TensorFlowNET.UnitTest/PythonTest.cs

Lines changed: 176 additions & 2 deletions
@@ -144,6 +144,37 @@ public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
             Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
         }

+        private class CollectionComparer : System.Collections.IComparer
+        {
+            private readonly double _epsilon;
+
+            public CollectionComparer(double eps = 1e-06) {
+                _epsilon = eps;
+            }
+            public int Compare(object x, object y)
+            {
+                var a = (double)x;
+                var b = (double)y;
+
+                double delta = Math.Abs(a - b);
+                if (delta < _epsilon)
+                {
+                    return 0;
+                }
+                return a.CompareTo(b);
+            }
+        }
+
+        public void assertAllCloseAccordingToType<T>(
+            T[] expected,
+            T[] given,
+            double eps = 1e-6,
+            float float_eps = 1e-6f)
+        {
+            // TODO: check if any of the arguments is not double and change the tolerance
+            CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps));
+        }
+
         public void assertProtoEquals(object toProto, object o)
         {
             throw new NotImplementedException();
@@ -153,6 +184,20 @@ public void assertProtoEquals(object toProto, object o)

         #region tensor evaluation and test session

+        private Session _cached_session = null;
+        private Graph _cached_graph = null;
+        private object _cached_config = null;
+        private bool _cached_force_gpu = false;
+
+        private void _ClearCachedSession()
+        {
+            if (self._cached_session != null)
+            {
+                self._cached_session.Dispose();
+                self._cached_session = null;
+            }
+        }
+
         //protected object _eval_helper(Tensor[] tensors)
         //{
         //    if (tensors == null)
@@ -218,9 +263,56 @@ public T evaluate<T>(Tensor tensor)
         }


-        public Session cached_session()
+        /// Returns a TensorFlow Session for use in executing tests.
+        public Session cached_session(
+            Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
         {
-            throw new NotImplementedException();
+            // This method behaves differently than self.session(): for performance reasons
+            // `cached_session` will by default reuse the same session within the same
+            // test. The session returned by this function will only be closed at the end
+            // of the test (in the TearDown function).
+
+            // Use the `use_gpu` and `force_gpu` options to control where ops are run. If
+            // `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
+            // `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
+            // possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned to
+            // the CPU.
+
+            // Example:
+            // python
+            //   class MyOperatorTest(test_util.TensorFlowTestCase):
+            //     def testMyOperator(self):
+            //       with self.cached_session() as sess:
+            //         valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
+            //         result = MyOperator(valid_input).eval()
+            //         self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
+            //         invalid_input = [-1.0, 2.0, 7.0]
+            //         with self.assertRaisesOpError("negative input not supported"):
+            //           MyOperator(invalid_input).eval()
+
+            // Args:
+            //   graph: Optional graph to use during the returned session.
+            //   config: An optional config_pb2.ConfigProto to use to configure the
+            //     session.
+            //   use_gpu: If True, attempt to run as many ops as possible on GPU.
+            //   force_gpu: If True, pin all ops to `/device:GPU:0`.
+
+            // Yields:
+            //   A Session object that should be used as a context manager to surround
+            //   the graph building and execution code in a test case.
+
+            // TODO:
+            // if context.executing_eagerly():
+            //     return self._eval_helper(tensors)
+            // else:
+            {
+                var sess = self._get_cached_session(
+                    graph, config, force_gpu, crash_if_inconsistent_args: true);
+                using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
+                return cached;
+            }
         }

         //Returns a TensorFlow Session for use in executing tests.
@@ -268,6 +360,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
             return s.as_default();
         }

+        private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
+        {
+            // Set the session and its graph to global default and constrain devices.
+            if (tf.executing_eagerly())
+                return null;
+            else
+            {
+                sess.graph.as_default();
+                sess.as_default();
+                {
+                    if (force_gpu)
+                    {
+                        // TODO:
+                        // Use the name of an actual device if one is detected, or
+                        // '/device:GPU:0' otherwise
+                        /* var gpu_name = gpu_device_name();
+                           if (!gpu_name)
+                               gpu_name = "/device:GPU:0"
+                           using (sess.graph.device(gpu_name)) {
+                               yield return sess;
+                           }*/
+                        return sess;
+                    }
+                    else if (use_gpu)
+                        return sess;
+                    else
+                        using (sess.graph.device("/device:CPU:0"))
+                            return sess;
+                }
+
+            }
+        }
+
         // See session() for details.
         private Session _create_session(Graph graph, object cfg, bool forceGpu)
         {
@@ -312,6 +438,54 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
             return new Session(graph);//, config = prepare_config(config))
         }

+        private Session _get_cached_session(
+            Graph graph = null,
+            object config = null,
+            bool force_gpu = false,
+            bool crash_if_inconsistent_args = true)
+        {
+            // See cached_session() for documentation.
+            if (self._cached_session == null)
+            {
+                var sess = self._create_session(graph, config, force_gpu);
+                self._cached_session = sess;
+                self._cached_graph = graph;
+                self._cached_config = config;
+                self._cached_force_gpu = force_gpu;
+                return sess;
+            }
+            else
+            {
+                if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
+                    throw new ValueError(@"The graph used to get the cached session is
+                        different than the one that was used to create the
+                        session. Maybe create a new session with
+                        self.session()");
+                if (crash_if_inconsistent_args && !self._cached_config.Equals(config))
+                {
+                    throw new ValueError(@"The config used to get the cached session is
+                        different than the one that was used to create the
+                        session. Maybe create a new session with
+                        self.session()");
+                }
+                if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu))
+                {
+                    throw new ValueError(@"The force_gpu value used to get the cached session is
+                        different than the one that was used to create the
+                        session. Maybe create a new session with
+                        self.session()");
+                }
+                return _cached_session;
+            }
+        }
+
+        [TestCleanup]
+        public void Cleanup()
+        {
+            _ClearCachedSession();
+        }
+
         #endregion

         public void AssetSequenceEqual<T>(T[] a, T[] b)
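Taken together, these additions give PythonTest a per-test cached Session (created lazily by _get_cached_session, reused by cached_session, and disposed in the [TestCleanup] Cleanup method) plus an element-wise tolerance assertion backed by CollectionComparer. A rough sketch of a test written against these helpers (the test class, tensors, and expected values are illustrative assumptions, not code from this commit):

using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;

[TestClass]
public class MyGraphTest : PythonTest   // hypothetical test class
{
    [TestMethod]
    public void TestAddsConstants()
    {
        tf.Graph().as_default();
        using (self.cached_session())   // the same Session is reused within this test, disposed in Cleanup()
        {
            var a = tf.constant(new[] { 1.0, 2.0 });
            var b = tf.constant(new[] { 0.5, 0.5 });
            var sum = tf.add(a, b);

            // Element-wise comparison using CollectionComparer's default 1e-6 epsilon.
            self.assertAllCloseAccordingToType(
                new double[] { 1.5, 2.5 },
                self.evaluate<double[]>(sum));
        }
    }
}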
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Linq;
+using System.Runtime.Intrinsics.X86;
+using System.Security.AccessControl;
+using Tensorflow.NumPy;
+using TensorFlowNET.UnitTest;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.UnitTest.Optimizers
+{
+    [TestClass]
+    public class GradientDescentOptimizerTest : PythonTest
+    {
+        private void TestBasicGeneric<T>() where T : struct
+        {
+            var dtype = Type.GetTypeCode(typeof(T)) switch
+            {
+                TypeCode.Single => np.float32,
+                TypeCode.Double => np.float64,
+                _ => throw new NotImplementedException(),
+            };
+
+            // train.GradientDescentOptimizer is V1 only API.
+            tf.Graph().as_default();
+            using (self.cached_session())
+            {
+                var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
+                var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
+                var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
+                var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
+                var optimizer = tf.train.GradientDescentOptimizer(3.0f);
+                var grads_and_vars = new[] {
+                    Tuple.Create(grads0, var0 as IVariableV1),
+                    Tuple.Create(grads1, var1 as IVariableV1)
+                };
+                var sgd_op = optimizer.apply_gradients(grads_and_vars);
+
+                var global_variables = variables.global_variables_initializer();
+                self.evaluate<T>(global_variables);
+                // Fetch params to validate initial values
+                // TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
+                self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0));
+                self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1));
+                // Run 1 step of sgd
+                sgd_op.run();
+                // Validate updated params
+                self.assertAllCloseAccordingToType(
+                    new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
+                    self.evaluate<double[]>(var0));
+                self.assertAllCloseAccordingToType(
+                    new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
+                    self.evaluate<double[]>(var1));
+                // TODO: self.assertEqual(0, len(optimizer.variables()));
+            }
+        }
+
+        [TestMethod]
+        public void TestBasic()
+        {
+            // TODO: add np.half
+            TestBasicGeneric<float>();
+            TestBasicGeneric<double>();
+        }
+    }
+}
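For reference, the expected values asserted after one training step follow the plain SGD update rule theta_new = theta - learning_rate * gradient with learning_rate = 3.0, worked out here for the test's inputs:

// var0: [1.0, 2.0] - 3.0 * [0.1, 0.1]   -> [0.7, 1.7]
// var1: [3.0, 4.0] - 3.0 * [0.01, 0.01] -> [2.97, 3.97]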
