1+ using Microsoft . VisualStudio . TestTools . UnitTesting ;
2+ using NumSharp ;
3+ using System ;
4+ using System . Linq ;
5+ using Tensorflow ;
6+ using static Tensorflow . Binding ;
7+
8+ namespace TensorFlowNET . UnitTest . Basics
9+ {
10+ [ TestClass ]
11+ public class RandomTest
12+ {
13+ /// <summary>
14+ /// Test the function of setting random seed
15+ /// This will help regenerate the same result
16+ /// </summary>
17+ [ TestMethod , Ignore ]
18+ public void TFRandomSeedTest ( )
19+ {
20+ var initValue = np . arange ( 6 ) . reshape ( 3 , 2 ) ;
21+ tf . set_random_seed ( 1234 ) ;
22+ var a1 = tf . random_uniform ( 1 ) ;
23+ var b1 = tf . random_shuffle ( tf . constant ( initValue ) ) ;
24+
25+ // This part we consider to be a refresh
26+ tf . set_random_seed ( 10 ) ;
27+ tf . random_uniform ( 1 ) ;
28+ tf . random_shuffle ( tf . constant ( initValue ) ) ;
29+
30+ tf . set_random_seed ( 1234 ) ;
31+ var a2 = tf . random_uniform ( 1 ) ;
32+ var b2 = tf . random_shuffle ( tf . constant ( initValue ) ) ;
33+ Assert . IsTrue ( a1 . numpy ( ) . array_equal ( a2 . numpy ( ) ) ) ;
34+ Assert . IsTrue ( b1 . numpy ( ) . array_equal ( b2 . numpy ( ) ) ) ;
35+ }
36+
37+ /// <summary>
38+ /// compare to Test above, seed is also added in params
39+ /// </summary>
40+ [ TestMethod , Ignore ]
41+ public void TFRandomSeedTest2 ( )
42+ {
43+ var initValue = np . arange ( 6 ) . reshape ( 3 , 2 ) ;
44+ tf . set_random_seed ( 1234 ) ;
45+ var a1 = tf . random_uniform ( 1 , seed : 1234 ) ;
46+ var b1 = tf . random_shuffle ( tf . constant ( initValue ) , seed : 1234 ) ;
47+
48+ // This part we consider to be a refresh
49+ tf . set_random_seed ( 10 ) ;
50+ tf . random_uniform ( 1 ) ;
51+ tf . random_shuffle ( tf . constant ( initValue ) ) ;
52+
53+ tf . set_random_seed ( 1234 ) ;
54+ var a2 = tf . random_uniform ( 1 ) ;
55+ var b2 = tf . random_shuffle ( tf . constant ( initValue ) ) ;
56+ Assert . IsTrue ( a1 . numpy ( ) . array_equal ( a2 . numpy ( ) ) ) ;
57+ Assert . IsTrue ( b1 . numpy ( ) . array_equal ( b2 . numpy ( ) ) ) ;
58+ }
59+
60+ /// <summary>
61+ /// This part we use funcs in tf.random rather than only tf
62+ /// </summary>
63+ [ TestMethod , Ignore ]
64+ public void TFRandomRaodomSeedTest ( )
65+ {
66+ tf . set_random_seed ( 1234 ) ;
67+ var a1 = tf . random . normal ( 1 ) ;
68+ var b1 = tf . random . truncated_normal ( 1 ) ;
69+
70+ // This part we consider to be a refresh
71+ tf . set_random_seed ( 10 ) ;
72+ tf . random . normal ( 1 ) ;
73+ tf . random . truncated_normal ( 1 ) ;
74+
75+ tf . set_random_seed ( 1234 ) ;
76+ var a2 = tf . random . normal ( 1 ) ;
77+ var b2 = tf . random . truncated_normal ( 1 ) ;
78+
79+ Assert . IsTrue ( a1 . numpy ( ) . array_equal ( a2 . numpy ( ) ) ) ;
80+ Assert . IsTrue ( b1 . numpy ( ) . array_equal ( b2 . numpy ( ) ) ) ;
81+ }
82+
83+ /// <summary>
84+ /// compare to Test above, seed is also added in params
85+ /// </summary>
86+ [ TestMethod , Ignore ]
87+ public void TFRandomRaodomSeedTest2 ( )
88+ {
89+ tf . set_random_seed ( 1234 ) ;
90+ var a1 = tf . random . normal ( 1 , seed : 1234 ) ;
91+ var b1 = tf . random . truncated_normal ( 1 ) ;
92+
93+ // This part we consider to be a refresh
94+ tf . set_random_seed ( 10 ) ;
95+ tf . random . normal ( 1 ) ;
96+ tf . random . truncated_normal ( 1 ) ;
97+
98+ tf . set_random_seed ( 1234 ) ;
99+ var a2 = tf . random . normal ( 1 , seed : 1234 ) ;
100+ var b2 = tf . random . truncated_normal ( 1 , seed : 1234 ) ;
101+
102+ Assert . IsTrue ( a1 . numpy ( ) . array_equal ( a2 . numpy ( ) ) ) ;
103+ Assert . IsTrue ( b1 . numpy ( ) . array_equal ( b2 . numpy ( ) ) ) ;
104+ }
105+ }
106+ }
0 commit comments