Skip to content

Commit b2b083a

Browse files
committed
Keras.Layers.BatchNormalization
1 parent 053f40b commit b2b083a

File tree

7 files changed

+212
-35
lines changed

7 files changed

+212
-35
lines changed

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace Tensorflow
88
public static partial class tf
99
{
1010
public static IInitializer zeros_initializer => new Zeros();
11+
public static IInitializer ones_initializer => new Ones();
1112
public static IInitializer glorot_uniform_initializer => new GlorotUniform();
1213

1314
public static variable_scope variable_scope(string name,

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,22 @@ public static Tensor batch_normalization(Tensor inputs,
8383
bool renorm = false,
8484
float renorm_momentum = 0.99f)
8585
{
86-
throw new NotImplementedException("batch_normalization");
86+
var layer = new BatchNormalization(
87+
axis: axis,
88+
momentum: momentum,
89+
epsilon: epsilon,
90+
center: center,
91+
scale: scale,
92+
beta_initializer: beta_initializer,
93+
gamma_initializer: gamma_initializer,
94+
moving_mean_initializer: moving_mean_initializer,
95+
moving_variance_initializer: moving_variance_initializer,
96+
renorm: renorm,
97+
renorm_momentum: renorm_momentum,
98+
trainable: trainable,
99+
name: name);
100+
101+
return layer.apply(inputs, training: training);
87102
}
88103
}
89104
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,33 @@ public class Layer : CheckpointableBase
1818
/// the layer's weights.
1919
/// </summary>
2020
protected bool built;
21-
21+
protected bool trainable;
22+
protected TF_DataType _dtype;
23+
/// <summary>
24+
/// A stateful layer is a layer whose updates are run during inference too,
25+
/// for instance stateful RNNs.
26+
/// </summary>
27+
protected bool stateful;
28+
/// <summary>
29+
/// Provides information about which inputs are compatible with the layer.
30+
/// </summary>
31+
protected InputSpec input_spec;
32+
protected bool supports_masking;
2233
protected List<RefVariable> _trainable_weights;
34+
protected string _name;
35+
protected string _base_name;
36+
protected bool _compute_previous_mask;
2337

24-
public Layer()
38+
/// <summary>
/// Sets up the state shared by every layer: trainability, dtype, the
/// stateful/built/masking flags, the layer name and an empty weight list.
/// </summary>
/// <param name="trainable">Stored in the <c>trainable</c> field.</param>
/// <param name="name">Passed to <c>_init_set_name</c>.</param>
/// <param name="dtype">Stored in the <c>_dtype</c> field.</param>
public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
    // Everything in this group is assigned before the virtual naming hook runs.
    built = false;
    stateful = false;
    supports_masking = false;
    _dtype = dtype;
    this.trainable = trainable;

    _init_set_name(name);

    // Assigned only after the naming hook has returned.
    _compute_previous_mask = false;
    _trainable_weights = new List<RefVariable>();
}
2849

2950
public Tensor __call__(Tensor inputs,
@@ -97,7 +118,7 @@ protected void _maybe_build(Tensor inputs)
97118

98119
/// <summary>
/// Hook in which a layer creates its weights for the given input shape.
/// This base implementation does nothing but throw; subclasses are expected
/// to override it.
/// </summary>
/// <param name="input_shape">Shape of the input the layer is being built for.</param>
/// <exception cref="NotImplementedException">Always thrown by this base implementation.</exception>
protected virtual void build(TensorShape input_shape)
99120
{
100-
121+
throw new NotImplementedException("Layer.build");
101122
}
102123

103124
protected virtual RefVariable add_weight(string name,
@@ -119,5 +140,18 @@ protected virtual RefVariable add_weight(string name,
119140

120141
return variable;
121142
}
143+
144+
/// <summary>
/// Initializes <c>_name</c> and <c>_base_name</c> for this layer.
/// </summary>
/// <param name="name">
/// Name requested by the caller. When null or empty, a unique name is
/// generated through <c>_make_unique_name</c>; otherwise the requested name
/// is used for both <c>_name</c> and <c>_base_name</c>.
/// </param>
protected virtual void _init_set_name(string name)
{
    if (string.IsNullOrEmpty(name))
    {
        (_name, _base_name) = _make_unique_name();
    }
    else
    {
        // Previously an explicit name was silently dropped, leaving both
        // fields null for every caller that passed one.
        _name = name;
        _base_name = name;
    }
}
149+
150+
/// <summary>
/// Generates a name for this layer from its runtime class name.
/// </summary>
/// <returns>
/// A tuple of (name returned by <c>base_layer_utils.unique_layer_name</c>,
/// snake_case base name that was passed to it).
/// </returns>
protected virtual (string, string) _make_unique_name()
{
    // Derive the base name from the concrete layer type instead of the former
    // hard-coded "conv2d": Conv2D still maps to "conv2d", while
    // BatchNormalization now maps to "batch_normalization".
    string base_name = _to_snake_case(GetType().Name);
    string name = base_layer_utils.unique_layer_name(base_name);
    return (name, base_name);
}

/// <summary>
/// Converts a CamelCase identifier to snake_case; a result that would start
/// with an underscore is prefixed with "private".
/// </summary>
private static string _to_snake_case(string name)
{
    // Insert '_' before each capitalized word that follows another character.
    string intermediate = System.Text.RegularExpressions.Regex.Replace(
        name, "(.)([A-Z][a-z0-9]+)", "$1_$2");
    // Insert '_' at any remaining lower-case -> upper-case boundary, then lower-case everything.
    string insecure = System.Text.RegularExpressions.Regex.Replace(
        intermediate, "([a-z])([A-Z])", "$1_$2").ToLowerInvariant();
    return insecure.StartsWith("_", System.StringComparison.Ordinal)
        ? "private" + insecure
        : insecure;
}
122156
}
123157
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Layers;
6+
7+
namespace Tensorflow.Keras.Layers
8+
{
9+
/// <summary>
/// Batch normalization layer. The constructor records the configuration and
/// <see cref="build"/> creates the gamma, beta, moving_mean and
/// moving_variance weights for the normalized axis.
/// </summary>
public class BatchNormalization : Layer
{
    private bool _USE_V2_BEHAVIOR = true;
    private float momentum;
    private float epsilon;
    private bool center;
    private bool scale;
    private bool renorm;
    private float renorm_momentum;
    private bool fused;
    private bool _bessels_correction_test_only;
    private int[] axis;
    private string _data_format;
    private IInitializer beta_initializer;
    private IInitializer gamma_initializer;
    private IInitializer moving_mean_initializer;
    private IInitializer moving_variance_initializer;
    private RefVariable gamma;
    private RefVariable beta;
    private RefVariable moving_mean;
    private RefVariable moving_variance;

    /// <summary>
    /// Stores the layer configuration. Initializers left null fall back to
    /// zeros (beta, moving mean) or ones (gamma, moving variance).
    /// </summary>
    /// <param name="axis">Axis to normalize; a negative value is resolved against the input rank in <see cref="build"/>.</param>
    /// <param name="momentum">Stored in the <c>momentum</c> field.</param>
    /// <param name="epsilon">Stored in the <c>epsilon</c> field.</param>
    /// <param name="center">Whether <see cref="build"/> creates the beta weight.</param>
    /// <param name="scale">Whether <see cref="build"/> creates the gamma weight.</param>
    /// <param name="beta_initializer">Initializer for beta; defaults to <c>tf.zeros_initializer</c>.</param>
    /// <param name="gamma_initializer">Initializer for gamma; defaults to <c>tf.ones_initializer</c>.</param>
    /// <param name="moving_mean_initializer">Initializer for the moving mean; defaults to <c>tf.zeros_initializer</c>.</param>
    /// <param name="moving_variance_initializer">Initializer for the moving variance; defaults to <c>tf.ones_initializer</c>.</param>
    /// <param name="renorm">Stored in the <c>renorm</c> field.</param>
    /// <param name="renorm_momentum">Stored in the <c>renorm_momentum</c> field.</param>
    /// <param name="trainable">Forwarded to the base layer constructor.</param>
    /// <param name="name">Forwarded to the base layer constructor.</param>
    public BatchNormalization(int axis = -1,
        float momentum = 0.99f,
        float epsilon = 0.001f,
        bool center = true,
        bool scale = true,
        IInitializer beta_initializer = null,
        IInitializer gamma_initializer = null,
        IInitializer moving_mean_initializer = null,
        IInitializer moving_variance_initializer = null,
        bool renorm = false,
        float renorm_momentum = 0.99f,
        bool trainable = true,
        string name = null) : base(trainable: trainable,
            name: name)
    {
        this.axis = new int[] { axis };
        this.momentum = momentum;
        this.epsilon = epsilon;
        this.center = center;
        this.scale = scale;
        this.beta_initializer = beta_initializer ?? tf.zeros_initializer;
        this.gamma_initializer = gamma_initializer ?? tf.ones_initializer;
        this.moving_mean_initializer = moving_mean_initializer ?? tf.zeros_initializer;
        this.moving_variance_initializer = moving_variance_initializer ?? tf.ones_initializer;
        this.renorm = renorm;
        // Previously accepted but dropped; keep it with the rest of the configuration.
        this.renorm_momentum = renorm_momentum;
        this.fused = true;
        this.supports_masking = true;
        this._bessels_correction_test_only = true;
    }

    /// <summary>
    /// Resolves negative axes against the input rank and creates the layer weights.
    /// </summary>
    /// <param name="input_shape">Shape of the input tensor.</param>
    /// <exception cref="ArgumentException">An axis lies outside the input rank.</exception>
    /// <exception cref="NotImplementedException"><c>scale</c> or <c>center</c> is false.</exception>
    protected override void build(TensorShape input_shape)
    {
        var ndims = input_shape.NDim;
        for (int idx = 0; idx < axis.Length; idx++)
        {
            if (axis[idx] < 0)
            {
                axis[idx] = ndims + axis[idx];
            }

            // Fail with a clear message instead of an index error further down.
            if (axis[idx] < 0 || axis[idx] >= ndims)
            {
                throw new ArgumentException($"Invalid axis: {axis[idx]}");
            }
        }

        if (fused && Enumerable.SequenceEqual(axis, new int[] { 3 }))
        {
            _data_format = "NHWC";
        }

        var param_dtype = _dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : _dtype;
        // One parameter per entry along the normalized axis.
        var param_shape = new int[] { input_shape.Dimensions[axis[0]] };

        if (scale)
        {
            gamma = add_weight("gamma",
                param_shape,
                dtype: param_dtype,
                initializer: gamma_initializer,
                trainable: true);
        }
        else
        {
            throw new NotImplementedException("add_weight gamma");
        }

        if (center)
        {
            beta = add_weight("beta",
                param_shape,
                dtype: param_dtype,
                initializer: beta_initializer,
                trainable: true);
        }
        else
        {
            throw new NotImplementedException("add_weight beta");
        }

        // The moving statistics use their configured initializers and are
        // registered as non-trainable, as in the reference Keras layer.
        moving_mean = add_weight("moving_mean",
            param_shape,
            dtype: param_dtype,
            initializer: moving_mean_initializer,
            trainable: false);

        moving_variance = add_weight("moving_variance",
            param_shape,
            dtype: param_dtype,
            initializer: moving_variance_initializer,
            trainable: false);
    }
}
109+
}

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,27 @@ namespace Tensorflow.Layers
77
{
88
public class Layer : Keras.Engine.Layer
99
{
10-
protected bool trainable;
11-
protected string _name;
12-
protected TF_DataType _dtype;
1310
protected Graph _graph;
14-
protected string _base_name;
11+
1512
protected VariableScope _scope;
1613
protected VariableScope _current_scope;
17-
/// <summary>
18-
/// A stateful layer is a layer whose updates are run during inference too,
19-
/// for instance stateful RNNs.
20-
/// </summary>
21-
protected bool stateful;
22-
/// <summary>
23-
/// Provides information about which inputs are compatible with the layer.
24-
/// </summary>
25-
protected InputSpec input_spec;
26-
protected bool supports_masking;
14+
2715
protected bool? _reuse;
16+
protected bool _use_resource_variables;
17+
protected bool _keras_style;
2818

2919
public Layer(bool trainable = true,
3020
string name = null,
3121
TF_DataType dtype = TF_DataType.DtInvalid,
32-
bool? _reuse = null) : base()
22+
bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype)
3323
{
34-
this.trainable = trainable;
35-
this.stateful = false;
24+
this._use_resource_variables = false;
3625
this._reuse = _reuse;
3726
this.built = false;
38-
this.supports_masking = false;
39-
_init_set_name(name);
27+
_keras_style = false;
4028
}
4129

42-
public Tensor apply(Tensor inputs)
30+
/// <summary>
/// Applies the layer to <paramref name="inputs"/> by delegating to <c>__call__</c>.
/// </summary>
/// <param name="inputs">Input tensor handed to <c>__call__</c>.</param>
/// <param name="training">Optional training tensor; see the review note in the body.</param>
/// <returns>The tensor returned by <c>__call__</c>.</returns>
public virtual Tensor apply(Tensor inputs, Tensor training = null)
3031
{
3132
// NOTE(review): `training` is accepted but never forwarded; only `inputs`
// reaches __call__, so callers passing `training:` currently get no effect
// from it. TODO confirm whether __call__ should receive it.
return __call__(inputs);
3233
}
@@ -126,18 +114,7 @@ protected virtual RefVariable add_weight(string name,
126114
});
127115
}
128116

129-
private void _init_set_name(string name)
130-
{
131-
if (string.IsNullOrEmpty(name))
132-
(_name, _base_name) = _make_unique_name();
133-
}
134117

135-
private (string, string) _make_unique_name()
136-
{
137-
string base_name = "conv2d";
138-
string name = base_layer_utils.unique_layer_name(base_name);
139-
return (name, base_name);
140-
}
141118

142119
protected override string _name_scope()
143120
{
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations.Initializers
6+
{
7+
/// <summary>
/// Initializer that delegates to <c>array_ops.ones</c> for the requested shape.
/// </summary>
public class Ones : IInitializer
{
    // Data type used when call() is not given an explicit one.
    private TF_DataType _dtype;

    /// <summary>Creates the initializer with the given default data type.</summary>
    public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
    {
        _dtype = dtype;
    }

    /// <summary>
    /// Builds the tensor for <paramref name="shape"/>.
    /// </summary>
    /// <param name="shape">Shape whose <c>Dimensions</c> are passed to <c>array_ops.ones</c>.</param>
    /// <param name="dtype">Data type to use; <c>DtInvalid</c> selects the type given to the constructor.</param>
    public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
    {
        var effective_dtype = dtype == TF_DataType.DtInvalid ? _dtype : dtype;
        return array_ops.ones(shape.Dimensions, effective_dtype);
    }

    /// <summary>Returns an anonymous object carrying the name of the default data type.</summary>
    public object get_config() => new { dtype = _dtype.name() };
}
29+
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT
9292
});
9393
}
9494

95+
/// <summary>
/// Builds a tensor of the given dimensions by filling it with a constant
/// created from <c>1.0f</c> in the base form of <paramref name="dtype"/>.
/// </summary>
/// <param name="dims">Dimensions of the result; converted to an int32 shape tensor.</param>
/// <param name="dtype">Element type; reduced with <c>as_base_dtype()</c> before use.</param>
/// <param name="name">Optional name passed to the name scope (default scope name "ones").</param>
/// <returns>The tensor produced by <c>gen_array_ops.fill</c>.</returns>
public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
    dtype = dtype.as_base_dtype();
    return with(ops.name_scope(name, "ones", new { dims }), scope =>
    {
        // From here on the op is named after the opened scope.
        name = scope;
        var shape_tensor = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32);
        var fill_value = constant_op.constant(1.0f, dtype: dtype);
        return gen_array_ops.fill(shape_tensor, fill_value, name: name);
    });
}
106+
95107
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
96108
{
97109
if( x == null && y == null)

0 commit comments

Comments
 (0)