Skip to content

Commit e5b40b0

Browse files
Add global_seed to context
1 parent 49cbae0 commit e5b40b0

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,12 @@ public Tensor random_shuffle(Tensor value, int? seed = null, string name = null)
9393
=> random_ops.random_shuffle(value, seed: seed, name: name);
9494

9595
public void set_random_seed(int seed)
96-
=> ops.get_default_graph().seed = seed;
96+
{
97+
if (executing_eagerly())
98+
Context.set_global_seed(seed);
99+
else
100+
ops.get_default_graph().seed = seed;
101+
}
97102

98103
public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
99104
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ public sealed partial class Context : IDisposable
4242
SafeContextHandle _handle;
4343
public SafeContextHandle Handle => _handle;
4444

45+
int? _seed;
46+
4547
public Context()
4648
{
4749
_device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT;
@@ -71,6 +73,12 @@ public void ensure_initialized()
7173
initialized = true;
7274
}
7375

76+
public void set_global_seed(int? seed)
77+
=> _seed = seed;
78+
79+
public int? global_seed()
80+
=> _seed;
81+
7482
public void start_step()
7583
=> c_api.TFE_ContextStartStep(_handle);
7684

src/TensorFlowNET.Core/Framework/random_seed.py.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using static Tensorflow.Binding;
18+
1719
namespace Tensorflow
1820
{
1921
public class random_seed
@@ -22,8 +24,18 @@ public class random_seed
2224

2325
public static (int?, int?) get_seed(int? op_seed = null)
2426
{
27+
int? global_seed;
28+
29+
if (tf.executing_eagerly())
30+
global_seed = tf.Context.global_seed();
31+
else
32+
global_seed = ops.get_default_graph().seed;
33+
34+
if (global_seed.HasValue)
35+
return (global_seed, op_seed);
36+
2537
if (op_seed.HasValue)
26-
return (DEFAULT_GRAPH_SEED, 0);
38+
return (DEFAULT_GRAPH_SEED, op_seed);
2739
else
2840
return (null, null);
2941
}

0 commit comments

Comments
 (0)