Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b951c79

Browse files
Lukasz Kaiser and Ryan Sepassi
authored and committed
Add the recent group normalization to common layers.
PiperOrigin-RevId: 191769014
1 parent b39d152 commit b951c79

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def layer_norm_compute(x, epsilon, scale, bias):
577577
def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
578578
"""Layer normalize the tensor x, averaging over the last dimension."""
579579
if filters is None:
580-
filters = x.get_shape()[-1]
580+
filters = shape_list(x)[-1]
581581
with tf.variable_scope(
582582
name, default_name="layer_norm", values=[x], reuse=reuse):
583583
scale = tf.get_variable(
@@ -592,6 +592,27 @@ def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
592592
return result
593593

594594

595+
def group_norm(x, filters=None, num_groups=8, epsilon=1e-5):
  """Group normalization as in https://arxiv.org/abs/1803.08494.

  Splits the channel axis into `num_groups` groups and normalizes each
  group over height, width, and the channels within that group, then
  applies a learned per-channel scale and bias.

  Args:
    x: a 4-D Tensor; the last axis is treated as channels (assumes an
      NHWC-style layout — the spatial axes are 1 and 2).
    filters: int, number of channels; inferred from x's last axis if None.
    num_groups: int, number of channel groups; must divide `filters` evenly.
    epsilon: small float added to the variance for numerical stability.

  Returns:
    A Tensor with the same shape as x, group-normalized.
  """
  shape = shape_list(x)
  if filters is None:
    filters = shape[-1]
  assert len(shape) == 4
  assert filters % num_groups == 0
  # Learned per-channel affine parameters.
  scale = tf.get_variable(
      "group_norm_scale", [filters], initializer=tf.ones_initializer())
  bias = tf.get_variable(
      "group_norm_bias", [filters], initializer=tf.zeros_initializer())
  scale = tf.cast(scale, x.dtype)
  bias = tf.cast(bias, x.dtype)
  epsilon = tf.cast(epsilon, x.dtype)
  # Split the channel axis into (num_groups, channels_per_group).
  channels_per_group = filters // num_groups
  grouped = tf.reshape(x, shape[:-1] + [num_groups, channels_per_group])
  # Moments over height, width, and channels within a group — not across
  # groups (axis 3), which is what distinguishes group norm from layer norm.
  mean, variance = tf.nn.moments(grouped, [1, 2, 4], keep_dims=True)
  normalized = (grouped - mean) * tf.rsqrt(variance + epsilon)
  return tf.reshape(normalized, shape) * scale + bias
614+
615+
595616
def noam_norm(x, epsilon=1.0, name=None):
596617
"""One version of layer normalization."""
597618
with tf.name_scope(name, default_name="noam_norm", values=[x]):
@@ -605,6 +626,8 @@ def apply_norm(x, norm_type, depth, epsilon):
605626
"""Apply Normalization."""
606627
if norm_type == "layer":
607628
return layer_norm(x, filters=depth, epsilon=epsilon)
629+
if norm_type == "group":
630+
return group_norm(x, filters=depth, epsilon=epsilon)
608631
if norm_type == "batch":
609632
return tf.layers.batch_normalization(x, epsilon=epsilon)
610633
if norm_type == "noam":

tensor2tensor/layers/common_layers_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ def testLayerNorm(self):
236236
res = session.run(y)
237237
self.assertEqual(res.shape, (5, 7, 11))
238238

239+
def testGroupNorm(self):
  """group_norm should preserve the shape of a 4-D input tensor."""
  inputs = np.random.rand(5, 7, 3, 16)
  with self.test_session() as session:
    outputs = common_layers.group_norm(tf.constant(inputs, dtype=tf.float32))
    session.run(tf.global_variables_initializer())
    result = session.run(outputs)
    self.assertEqual(result.shape, (5, 7, 3, 16))
246+
239247
def testConvLSTM(self):
240248
x = np.random.rand(5, 7, 11, 13)
241249
with self.test_session() as session:

0 commit comments

Comments (0)