Skip to content

Commit 2cb5fd6

Browse files
new graph
1 parent 149caae commit 2cb5fd6

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

test/TensorFlowNET.UnitTest/Training/BasicLinearModel.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ public class BasicLinearModel
1515
[TestMethod]
1616
public void LinearRegression()
1717
{
18+
tf.Graph().as_default();
19+
1820
// Initialize the weights to `5.0` and the bias to `0.0`
1921
// In practice, these should be initialized to random values (for example, with `tf.random.normal`)
2022
var W = tf.Variable(5.0f);

test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
using Microsoft.VisualStudio.TestPlatform.Utilities;
2-
using Microsoft.VisualStudio.TestTools.UnitTesting;
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
32
using System;
4-
using System.Diagnostics;
5-
using System.Linq;
63
using Tensorflow.NumPy;
74
using TensorFlowNET.UnitTest;
85
using static Tensorflow.Binding;
@@ -27,8 +24,8 @@ private void TestBasic<T>() where T : struct
2724
var dtype = GetTypeForNumericType<T>();
2825

2926
// train.GradientDescentOptimizer is V1 only API.
30-
//tf.Graph().as_default();
31-
/*using (var sess = self.cached_session())
27+
tf.Graph().as_default();
28+
using (var sess = self.cached_session())
3229
{
3330
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
3431
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
@@ -59,15 +56,15 @@ private void TestBasic<T>() where T : struct
5956
new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
6057
self.evaluate<T[]>(var1));
6158
// TODO: self.assertEqual(0, len(optimizer.variables()));
62-
}*/
59+
}
6360
}
6461

6562
[TestMethod]
6663
public void TestBasic()
6764
{
6865
//TODO: add np.half
6966
TestBasic<float>();
70-
// TestBasic<double>();
67+
TestBasic<double>();
7168
}
7269

7370
private void TestTensorLearningRate<T>() where T : struct
@@ -115,8 +112,8 @@ private void TestTensorLearningRate<T>() where T : struct
115112
public void TestTensorLearningRate()
116113
{
117114
//TODO: add np.half
118-
// TestTensorLearningRate<float>();
119-
// TestTensorLearningRate<double>();
115+
TestTensorLearningRate<float>();
116+
TestTensorLearningRate<double>();
120117
}
121118
}
122119
}

0 commit comments

Comments
 (0)