Skip to content

Commit f7b8dba

Browse files
small fixes
1 parent 2a377e2 commit f7b8dba

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Microsoft.VisualStudio.TestTools.UnitTesting;
1+
using Microsoft.VisualStudio.TestPlatform.Utilities;
2+
using Microsoft.VisualStudio.TestTools.UnitTesting;
23
using System;
34
using System.Linq;
45
using Tensorflow.NumPy;
@@ -20,7 +21,7 @@ private static TF_DataType GetTypeForNumericType<T>() where T : struct
2021
};
2122
}
2223

23-
private void TestBasicGeneric<T>() where T : struct
24+
private void TestBasic<T>() where T : struct
2425
{
2526
var dtype = GetTypeForNumericType<T>();
2627

@@ -42,11 +43,9 @@ private void TestBasicGeneric<T>() where T : struct
4243
var global_variables = tf.global_variables_initializer();
4344
sess.run(global_variables);
4445

45-
// Fetch params to validate initial values
4646
var initialVar0 = sess.run(var0);
47-
var valu = var0.eval(sess);
4847
var initialVar1 = sess.run(var1);
49-
// TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
48+
// Fetch params to validate initial values
5049
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
5150
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
5251
// Run 1 step of sgd
@@ -66,10 +65,9 @@ private void TestBasicGeneric<T>() where T : struct
6665
public void TestBasic()
6766
{
6867
//TODO: add np.half
69-
TestBasicGeneric<float>();
70-
TestBasicGeneric<double>();
68+
TestBasic<float>();
69+
TestBasic<double>();
7170
}
7271

73-
7472
}
7573
}

0 commit comments

Comments
 (0)