Skip to content

Commit d3e85fe

Browse files
Pass named args to RandomShuffle; seed handling more simular to python
1 parent e5b40b0 commit d3e85fe

File tree

4 files changed

+37
-7
lines changed

4 files changed

+37
-7
lines changed

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public sealed partial class Context : IDisposable
4343
public SafeContextHandle Handle => _handle;
4444

4545
int? _seed;
46+
Random _rng;
4647

4748
public Context()
4849
{
@@ -74,11 +75,23 @@ public void ensure_initialized()
7475
}
7576

7677
public void set_global_seed(int? seed)
77-
=> _seed = seed;
78+
{
79+
_seed = seed;
80+
if (seed.HasValue)
81+
_rng = new Random(seed.Value);
82+
else
83+
_rng = null;
84+
// Also clear the kernel cache, to reset any existing seeds
85+
if (_handle != null)
86+
c_api.TFE_ContextClearCaches(_handle);
87+
}
7888

7989
public int? global_seed()
8090
=> _seed;
8191

92+
public int? internal_operation_seed()
93+
=> _rng?.Next(0, int.MaxValue);
94+
8295
public void start_step()
8396
=> c_api.TFE_ContextStartStep(_handle);
8497

@@ -94,7 +107,7 @@ public bool executing_eagerly()
94107
{
95108
if(context_switches.Count() == 0)
96109
tf.enable_eager_execution();
97-
110+
98111
return context_switches.Current().EagerMode;
99112
}
100113

src/TensorFlowNET.Core/Framework/random_seed.py.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Collections.Generic;
1718
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow
2021
{
2122
public class random_seed
2223
{
2324
private static int DEFAULT_GRAPH_SEED = 87654321;
25+
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>();
2426

2527
public static (int?, int?) get_seed(int? op_seed = null)
2628
{
@@ -32,7 +34,20 @@ public static (int?, int?) get_seed(int? op_seed = null)
3234
global_seed = ops.get_default_graph().seed;
3335

3436
if (global_seed.HasValue)
37+
{
38+
if (!op_seed.HasValue)
39+
if (tf.executing_eagerly())
40+
op_seed = tf.Context.internal_operation_seed();
41+
else
42+
{
43+
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed))
44+
seed = 0;
45+
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1;
46+
op_seed = seed;
47+
}
48+
3549
return (global_seed, op_seed);
50+
}
3651

3752
if (op_seed.HasValue)
3853
return (DEFAULT_GRAPH_SEED, op_seed);

src/TensorFlowNET.Core/Operations/gen_random_ops.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
131131
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
132132
"RandomShuffle", name,
133133
null,
134-
value, seed, seed2);
134+
value,
135+
"seed", seed,
136+
"seed2", seed2);
135137

136138
return results[0];
137139
}

test/TensorFlowNET.UnitTest/Basics/RandomTest.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class RandomTest
1414
/// Test the function of setting random seed
1515
/// This will help regenerate the same result
1616
/// </summary>
17-
[TestMethod, Ignore]
17+
[TestMethod]
1818
public void TFRandomSeedTest()
1919
{
2020
var initValue = np.arange(6).reshape(3, 2);
@@ -37,7 +37,7 @@ public void TFRandomSeedTest()
3737
/// <summary>
3838
/// compare to Test above, seed is also added in params
3939
/// </summary>
40-
[TestMethod, Ignore]
40+
[TestMethod]
4141
public void TFRandomSeedTest2()
4242
{
4343
var initValue = np.arange(6).reshape(3, 2);
@@ -60,7 +60,7 @@ public void TFRandomSeedTest2()
6060
/// <summary>
6161
/// This part we use funcs in tf.random rather than only tf
6262
/// </summary>
63-
[TestMethod, Ignore]
63+
[TestMethod]
6464
public void TFRandomRaodomSeedTest()
6565
{
6666
tf.set_random_seed(1234);
@@ -83,7 +83,7 @@ public void TFRandomRaodomSeedTest()
8383
/// <summary>
8484
/// compare to Test above, seed is also added in params
8585
/// </summary>
86-
[TestMethod, Ignore]
86+
[TestMethod]
8787
public void TFRandomRaodomSeedTest2()
8888
{
8989
tf.set_random_seed(1234);

0 commit comments

Comments
 (0)