Skip to content

Commit 3805771

Browse files
author
Beacontownfc
committed
improve layer norm
1 parent 4c6063d commit 3805771

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Xml.Linq;
1718
using Tensorflow.Operations;
1819
using Tensorflow.Operations.Activation;
20+
//using static System.Formats.Asn1.AsnWriter;
1921
using static Tensorflow.Binding;
2022

2123
namespace Tensorflow
@@ -125,6 +127,22 @@ public Tensor[] fused_batch_norm(Tensor x,
125127
is_training: is_training,
126128
name: name,
127129
exponential_avg_factor: exponential_avg_factor);
130+
/// <summary>
/// Applies batch normalization: <c>(x - mean) * rsqrt(variance + epsilon) * scale + offset</c>,
/// algebraically folded as <c>x * inv + (offset - mean * inv)</c> where
/// <c>inv = rsqrt(variance + epsilon) * scale</c>.
/// Mirrors the Python reference implementation in tensorflow/python/ops/nn_impl.py.
/// </summary>
/// <param name="x">Input tensor of arbitrary dimensionality.</param>
/// <param name="mean">Mean tensor (must broadcast against <paramref name="x"/> — not validated here).</param>
/// <param name="variance">Variance tensor (same broadcast requirement).</param>
/// <param name="offset">Offset (beta) tensor, or null for no offset.</param>
/// <param name="scale">Scale (gamma) tensor, or null for no scaling.</param>
/// <param name="variance_epsilon">Small float added to variance to avoid division by zero.</param>
/// <param name="name">Optional name scope for the created ops; defaults to "batchnorm".</param>
/// <returns>The normalized tensor, cast to the dtype of <paramref name="x"/>.</returns>
public Tensor batch_normalization(Tensor x,
    Tensor mean,
    Tensor variance,
    Tensor offset,
    Tensor scale,
    float variance_epsilon,
    string name = null)
{
    // All ops must be created INSIDE the name scope so they are grouped under
    // "batchnorm" in the graph (previously only the scale multiply was scoped;
    // the rsqrt and the output arithmetic escaped the scope).
    return tf_with(ops.name_scope(name, "batchnorm", (x, mean, variance, scale, offset)), scope =>
    {
        var inv = math_ops.rsqrt(variance + variance_epsilon);
        if (scale != null)
            inv *= scale;
        // Fold mean/offset into a single additive term; cast both factors to
        // x.dtype so mixed-precision inputs (e.g. float16 x, float32 stats) work.
        if (offset != null)
            return x * math_ops.cast(inv, x.dtype) + math_ops.cast(offset - mean * inv, dtype: x.dtype);
        else
            return x * math_ops.cast(inv, x.dtype) + math_ops.cast(-mean * inv, dtype: x.dtype);
    });
}
128146

129147
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
130148
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);

src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,22 @@ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? trai
153153
}
154154
else
155155
{
156+
var input_dtype = inputs.dtype;
157+
if ((input_dtype == tf.float16) && DType == tf.float32) inputs = tf.cast(inputs, tf.float32);
158+
(Tensor mean, Tensor variance) = tf.nn.moments(inputs, axis, keep_dims: true);
156159

157-
}
160+
(Tensor scale, Tensor offset) = (_broadcast(gamma), _broadcast(beta));
161+
162+
outputs = tf.nn.batch_normalization(
163+
inputs,
164+
mean,
165+
variance,
166+
offset: offset,
167+
scale: scale,
168+
variance_epsilon: epsilon);
158169

170+
outputs = tf.cast(outputs, input_dtype);
171+
}
159172
// If some components of the shape got lost due to adjustments, fix that.
160173
outputs.shape = input_shape;
161174

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
23
using System.Collections.Generic;
4+
using System.Linq;
35
using Tensorflow.NumPy;
46
using static Tensorflow.Binding;
57
using static Tensorflow.KerasApi;
@@ -161,6 +163,26 @@ public void LayerNormalization()
161163
Tensor output = layer.Apply(inputs);
162164
Assert.AreEqual((5, 2), output.shape);
163165
Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));
166+
167+
// test_layernorm_weights
168+
Assert.AreEqual(len(layer.TrainableWeights), 2);
169+
Assert.AreEqual(len(layer.Weights), 2);
170+
171+
var beta = layer.Weights.Where(x => x.Name.StartsWith("beta")).Single();
172+
var gamma = layer.Weights.Where(x => x.Name.StartsWith("gamma")).Single();
173+
174+
// correctness_test
175+
layer = keras.layers.LayerNormalization(axis: -1, epsilon: (float) 1e-12);
176+
var x = np.random.normal(loc: 5.0f, scale: 10.0f, size: (1000, 2, 2, 2)).astype(tf.float32);
177+
178+
output = layer.Apply(x);
179+
180+
var y = (output - beta.numpy()) / gamma.numpy();
181+
182+
var y_mean = np.mean(y.numpy());
183+
var y_std = np.sqrt(np.sum(np.power(y.numpy() - np.mean(y.numpy()), 2)) / 8000);
184+
Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_std - 1.0)).ToArray<bool>()[0]);
185+
Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_mean)).ToArray<bool>()[0]);
164186
}
165187

166188
/// <summary>

0 commit comments

Comments
 (0)