Skip to content

Commit 958fdb8

Browse files
committed
gen_random_ops.truncated_normal
1 parent 9f0c1e5 commit 958fdb8

File tree

9 files changed

+84
-6
lines changed

9 files changed

+84
-6
lines changed

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ public class Layer : CheckpointableBase
1919
/// </summary>
2020
protected bool built;
2121

22+
protected List<RefVariable> _trainable_weights;
23+
24+
/// <summary>
/// Creates the layer with an empty list of trainable weights.
/// </summary>
public Layer() => _trainable_weights = new List<RefVariable>();
28+
2229
public Tensor __call__(Tensor inputs,
2330
VariableScope scope = null)
2431
{
@@ -36,6 +43,7 @@ public Tensor __call__(Tensor inputs,
3643
if (!built)
3744
{
3845
_maybe_build(inputs);
46+
built = true;
3947
}
4048
});
4149

@@ -65,13 +73,15 @@ protected virtual void add_weight(string name,
6573
bool? trainable = null,
6674
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
6775
{
68-
_add_variable_with_custom_getter(name,
76+
var variable = _add_variable_with_custom_getter(name,
6977
shape,
7078
dtype: dtype,
7179
getter: getter,
7280
overwrite: true,
7381
initializer: initializer,
7482
trainable: trainable.Value);
83+
backend.track_variable(variable);
84+
_trainable_weights.Add(variable);
7585
}
7686
}
7787
}

src/TensorFlowNET.Core/Keras/Initializers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class Initializers
1414
/// <returns></returns>
1515
public IInitializer he_normal(int? seed = null)
{
    // He normal is variance scaling over fan_in with scale 2.0, sampled from a
    // truncated normal. The stale `scale: 20f` return has been dropped so that
    // only the corrected value remains.
    return new VarianceScaling(
        scale: 2.0f,
        mode: "fan_in",
        distribution: "truncated_normal",
        seed: seed);
}
1919
}
2020
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras
{
    /// <summary>
    /// Holder for static Keras backend helpers.
    /// </summary>
    public class backend
    {
        /// <summary>
        /// Hook that receives a variable for tracking. No tracking logic is
        /// implemented here; the body is empty.
        /// </summary>
        /// <param name="v">The variable handed over for tracking.</param>
        public static void track_variable(RefVariable v)
        {
            // Nothing is done with the variable.
        }
    }
}

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class Layer : Keras.Engine.Layer
2929
public Layer(bool trainable = true,
3030
string name = null,
3131
TF_DataType dtype = TF_DataType.DtInvalid,
32-
bool? _reuse = null)
32+
bool? _reuse = null) : base()
3333
{
3434
this.trainable = trainable;
3535
this.stateful = false;

src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ public Tensor call(TensorShape shape, TF_DataType dtype)
4343

4444
if (_distribution == "normal" || _distribution == "truncated_normal")
4545
{
46-
throw new NotImplementedException("truncated_normal");
46+
float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f;
47+
return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed);
4748
}
4849
else if (_distribution == "untruncated_normal")
4950
{

src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,28 @@ public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed =
5353

5454
return _op.outputs[0];
5555
}
56+
57+
/// <summary>
/// Outputs random values from a truncated normal distribution.
/// </summary>
/// <param name="shape">Tensor passed to the op as its <c>shape</c> argument.</param>
/// <param name="dtype">Data type passed to the op as its <c>dtype</c> argument.</param>
/// <param name="seed">First seed value; <c>null</c> is replaced by 0.</param>
/// <param name="seed2">Second seed value; <c>null</c> is replaced by 0.</param>
/// <param name="name">Optional name forwarded to the op helper.</param>
/// <returns>The first output of the created "TruncatedNormal" op.</returns>
public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
{
    // A missing seed is treated as 0 before it is handed to the op.
    seed = seed ?? 0;
    seed2 = seed2 ?? 0;

    var op_args = new { shape, dtype, seed, seed2 };
    var truncated_op = _op_def_lib._apply_op_helper("TruncatedNormal",
        name: name,
        args: op_args);

    return truncated_op.outputs[0];
}
5679
}
5780
}

src/TensorFlowNET.Core/Operations/random_ops.py.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,27 @@ public static Tensor random_uniform(int[] shape,
6464
});
6565
}
6666

67+
/// <summary>
/// Builds a truncated-normal sample of the given shape, multiplies it by
/// <paramref name="stddev"/> and adds <paramref name="mean"/>.
/// </summary>
/// <param name="shape">Dimensions of the output, converted to a shape tensor.</param>
/// <param name="mean">Value added to the scaled sample.</param>
/// <param name="stddev">Factor the raw sample is multiplied by.</param>
/// <param name="dtype">Data type used for the sample, mean and stddev tensors.</param>
/// <param name="seed">Optional seed passed to <c>random_seed.get_seed</c>.</param>
/// <param name="name">Optional name for the enclosing name scope.</param>
/// <returns>The tensor produced by the final add op.</returns>
public static Tensor truncated_normal(int[] shape,
    float mean = 0.0f,
    float stddev = 1.0f,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    int? seed = null,
    string name = null)
{
    var name_scope = ops.name_scope(name, "truncated_normal", new { shape, mean, stddev });

    return with(name_scope, scope =>
    {
        // The resolved scope becomes the name of the final add op.
        name = scope;

        // Ops are created in the same order as before: shape, mean, stddev.
        var dims = _ShapeTensor(shape);
        var mean_t = ops.convert_to_tensor(mean, dtype: dtype, name: "mean");
        var stddev_t = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev");

        var (seed1, seed2) = random_seed.get_seed(seed);
        var sample = gen_random_ops.truncated_normal(dims, dtype, seed: seed1, seed2: seed2);

        // Scale the raw sample, then shift it by the mean.
        var scaled = sample * stddev_t;
        return math_ops.add(scaled, mean_t, name: name);
    });
}
87+
6788
private static Tensor _ShapeTensor(int[] shape)
6889
{
6990
return ops.convert_to_tensor(shape, name: "shape");

src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,16 @@ protected virtual RefVariable _add_variable_with_custom_getter(string name,
1919
bool trainable = false)
2020
{
2121
var new_variable = getter(name, shape, dtype, initializer, trainable);
22-
throw new NotImplementedException("_add_variable_with_custom_getter");
22+
if (!overwrite || new_variable is RefVariable)
23+
return _track_checkpointable(new_variable, name: name,
24+
overwrite: overwrite);
25+
else
26+
return new_variable;
27+
}
28+
29+
/// <summary>
/// Returns the given variable unchanged; <paramref name="name"/> and
/// <paramref name="overwrite"/> are accepted but not used.
/// </summary>
/// <param name="checkpointable">The variable to hand back.</param>
/// <param name="name">Name associated with the variable (unused here).</param>
/// <param name="overwrite">Overwrite flag (unused here).</param>
/// <returns>The same <paramref name="checkpointable"/> instance.</returns>
protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
    => checkpointable;
2433
}
2534
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ private void _init_from_args(object initial_value,
136136
{
137137
_initial_value = (initial_value as Func<Tensor>)();
138138
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
139-
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
140139
});
140+
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
141141
}
142142
// Or get the initial value from a Tensor or Python object.
143143
else

0 commit comments

Comments
 (0)