Skip to content

Commit c909310

Browse files
committed
variance_scaling_initializer
1 parent 18b078d commit c909310

File tree

8 files changed

+48
-18
lines changed

8 files changed

+48
-18
lines changed

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,20 @@ public IInitializer random_normal_initializer(float mean = 0.0f,
6666
/// <summary>
6767
/// Initializer capable of adapting its scale to the shape of weights tensors.
6868
/// </summary>
69-
/// <param name="scale"></param>
69+
/// <param name="factor"></param>
7070
/// <param name="mode"></param>
7171
/// <param name="distribution"></param>
7272
/// <param name="seed"></param>
7373
/// <param name="dtype"></param>
7474
/// <returns></returns>
75-
public IInitializer variance_scaling_initializer(float scale = 1.0f,
76-
string mode = "fan_in",
77-
string distribution = "truncated_normal",
75+
public IInitializer variance_scaling_initializer(float factor = 1.0f,
76+
string mode = "FAN_IN",
77+
bool uniform = false,
7878
int? seed = null,
7979
TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling(
80-
scale: scale,
80+
factor: factor,
8181
mode: mode,
82-
distribution: distribution,
82+
uniform: uniform,
8383
seed: seed,
8484
dtype: dtype);
8585
}

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,21 @@ public partial class tensorflow
2828
/// <param name="seed"></param>
2929
/// <param name="name"></param>
3030
/// <returns></returns>
31-
public Tensor random_normal(int[] shape,
31+
public Tensor random_normal(TensorShape shape,
3232
float mean = 0.0f,
3333
float stddev = 1.0f,
3434
TF_DataType dtype = TF_DataType.TF_FLOAT,
3535
int? seed = null,
3636
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
3737

38-
public Tensor random_uniform(int[] shape,
38+
public Tensor random_uniform(TensorShape shape,
3939
float minval = 0,
4040
float maxval = 1,
4141
TF_DataType dtype = TF_DataType.TF_FLOAT,
4242
int? seed = null,
4343
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
4444

45-
public Tensor truncated_normal(int[] shape,
45+
public Tensor truncated_normal(TensorShape shape,
4646
float mean = 0.0f,
4747
float stddev = 1.0f,
4848
TF_DataType dtype = TF_DataType.TF_FLOAT,
@@ -62,5 +62,8 @@ public Tensor truncated_normal(int[] shape,
6262
/// </returns>
6363
public Tensor random_shuffle(Tensor value, int? seed = null, string name = null)
6464
=> random_ops.random_shuffle(value, seed: seed, name: name);
65+
66+
public void set_random_seed(int seed)
67+
=> ops.get_default_graph().seed = seed;
6568
}
6669
}

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ public static double sum(IEnumerable enumerable)
273273
return sum;
274274
}
275275

276+
public static double sum(IEnumerable<int> enumerable)
277+
=> enumerable.Sum();
278+
276279
public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
277280
{
278281
return sum(values.Keys);

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace Tensorflow.Data
1+
using System;
2+
3+
namespace Tensorflow.Data
24
{
35
/// <summary>
46
/// Represents a potentially large set of elements.
@@ -11,5 +13,9 @@
1113
/// </summary>
1214
public class DatasetV2
1315
{
16+
public static DatasetV2 from_generator()
17+
{
18+
throw new NotImplementedException("");
19+
}
1420
}
1521
}

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ public partial class Graph : DisposableObject
107107

108108
public bool building_function;
109109

110+
int _seed;
111+
public int seed
112+
{
113+
get => _seed;
114+
set
115+
{
116+
_seed = value;
117+
}
118+
}
119+
110120
public Graph()
111121
{
112122
_handle = c_api.TF_NewGraph();

src/TensorFlowNET.Core/Keras/Initializers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class Initializers
2727
/// <returns></returns>
2828
public IInitializer he_normal(int? seed = null)
2929
{
30-
return new VarianceScaling(scale: 2.0f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
30+
return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed);
3131
}
3232
}
3333
}

src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public GlorotUniform(float scale = 1.0f,
2222
string mode = "fan_avg",
2323
string distribution = "uniform",
2424
int? seed = null,
25-
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
25+
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
26+
mode: mode,
27+
seed: seed,
28+
dtype: dtype)
2629
{
2730

2831
}

src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,22 @@ public class VarianceScaling : IInitializer
3131
protected int? _seed;
3232
protected TF_DataType _dtype;
3333

34-
public VarianceScaling(float scale = 1.0f,
35-
string mode = "fan_in",
36-
string distribution = "truncated_normal",
34+
public VarianceScaling(float factor = 2.0f,
35+
string mode = "FAN_IN",
36+
bool uniform = false,
3737
int? seed = null,
3838
TF_DataType dtype = TF_DataType.TF_FLOAT)
3939
{
40-
if (scale < 0)
40+
if (!dtype.is_floating())
41+
throw new TypeError("Cannot create initializer for non-floating point type.");
42+
if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode))
43+
throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]");
44+
45+
if (factor < 0)
4146
throw new ValueError("`scale` must be positive float.");
42-
_scale = scale;
47+
48+
_scale = factor;
4349
_mode = mode;
44-
_distribution = distribution;
4550
_seed = seed;
4651
_dtype = dtype;
4752
}

0 commit comments

Comments
 (0)