Skip to content

Commit fb903d8

Browse files
lsylusiyaoEsther2013
authored andcommitted
Add random seed test to help reproduce training
Currently the tests are set ignored because here's bug
1 parent 6a8665f commit fb903d8

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using NumSharp;
3+
using System;
4+
using System.Linq;
5+
using Tensorflow;
6+
using static Tensorflow.Binding;
7+
8+
namespace TensorFlowNET.UnitTest.Basics
9+
{
10+
[TestClass]
11+
public class RandomTest
12+
{
13+
/// <summary>
14+
/// Test the function of setting random seed
15+
/// This will help regenerate the same result
16+
/// </summary>
17+
[TestMethod, Ignore]
18+
public void TFRandomSeedTest()
19+
{
20+
var initValue = np.arange(6).reshape(3, 2);
21+
tf.set_random_seed(1234);
22+
var a1 = tf.random_uniform(1);
23+
var b1 = tf.random_shuffle(tf.constant(initValue));
24+
25+
// This part we consider to be a refresh
26+
tf.set_random_seed(10);
27+
tf.random_uniform(1);
28+
tf.random_shuffle(tf.constant(initValue));
29+
30+
tf.set_random_seed(1234);
31+
var a2 = tf.random_uniform(1);
32+
var b2 = tf.random_shuffle(tf.constant(initValue));
33+
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
34+
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
35+
}
36+
37+
/// <summary>
38+
/// compare to Test above, seed is also added in params
39+
/// </summary>
40+
[TestMethod, Ignore]
41+
public void TFRandomSeedTest2()
42+
{
43+
var initValue = np.arange(6).reshape(3, 2);
44+
tf.set_random_seed(1234);
45+
var a1 = tf.random_uniform(1, seed:1234);
46+
var b1 = tf.random_shuffle(tf.constant(initValue), seed: 1234);
47+
48+
// This part we consider to be a refresh
49+
tf.set_random_seed(10);
50+
tf.random_uniform(1);
51+
tf.random_shuffle(tf.constant(initValue));
52+
53+
tf.set_random_seed(1234);
54+
var a2 = tf.random_uniform(1);
55+
var b2 = tf.random_shuffle(tf.constant(initValue));
56+
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
57+
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
58+
}
59+
60+
/// <summary>
61+
/// This part we use funcs in tf.random rather than only tf
62+
/// </summary>
63+
[TestMethod, Ignore]
64+
public void TFRandomRaodomSeedTest()
65+
{
66+
tf.set_random_seed(1234);
67+
var a1 = tf.random.normal(1);
68+
var b1 = tf.random.truncated_normal(1);
69+
70+
// This part we consider to be a refresh
71+
tf.set_random_seed(10);
72+
tf.random.normal(1);
73+
tf.random.truncated_normal(1);
74+
75+
tf.set_random_seed(1234);
76+
var a2 = tf.random.normal(1);
77+
var b2 = tf.random.truncated_normal(1);
78+
79+
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
80+
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
81+
}
82+
83+
/// <summary>
84+
/// compare to Test above, seed is also added in params
85+
/// </summary>
86+
[TestMethod, Ignore]
87+
public void TFRandomRaodomSeedTest2()
88+
{
89+
tf.set_random_seed(1234);
90+
var a1 = tf.random.normal(1, seed:1234);
91+
var b1 = tf.random.truncated_normal(1);
92+
93+
// This part we consider to be a refresh
94+
tf.set_random_seed(10);
95+
tf.random.normal(1);
96+
tf.random.truncated_normal(1);
97+
98+
tf.set_random_seed(1234);
99+
var a2 = tf.random.normal(1, seed:1234);
100+
var b2 = tf.random.truncated_normal(1, seed:1234);
101+
102+
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
103+
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
104+
}
105+
}
106+
}

0 commit comments

Comments
 (0)