@@ -577,7 +577,7 @@ def layer_norm_compute(x, epsilon, scale, bias):
 def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
   """Layer normalize the tensor x, averaging over the last dimension."""
   if filters is None:
-    filters = x.get_shape()[-1]
+    filters = shape_list(x)[-1]
   with tf.variable_scope(
       name, default_name="layer_norm", values=[x], reuse=reuse):
     scale = tf.get_variable(
@@ -592,6 +592,27 @@ def layer_norm(x, filters=None, epsilon=1e-6, name=None, reuse=None):
     return result
 
 
+def group_norm(x, filters=None, num_groups=8, epsilon=1e-5):
+  """Group normalization as in https://arxiv.org/abs/1803.08494."""
+  x_shape = shape_list(x)
+  if filters is None:
+    filters = x_shape[-1]
+  assert len(x_shape) == 4
+  assert filters % num_groups == 0
+  # Prepare variables.
+  scale = tf.get_variable(
+      "group_norm_scale", [filters], initializer=tf.ones_initializer())
+  bias = tf.get_variable(
+      "group_norm_bias", [filters], initializer=tf.zeros_initializer())
+  epsilon, scale, bias = [tf.cast(t, x.dtype) for t in [epsilon, scale, bias]]
+  # Reshape and compute group norm.
+  x = tf.reshape(x, x_shape[:-1] + [num_groups, filters // num_groups])
+  # Calculate mean and variance over height, width and channels (not groups).
+  mean, variance = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
+  norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
+  return tf.reshape(norm_x, x_shape) * scale + bias
+
+
 def noam_norm(x, epsilon=1.0, name=None):
   """One version of layer normalization."""
   with tf.name_scope(name, default_name="noam_norm", values=[x]):
@@ -605,6 +626,8 @@ def apply_norm(x, norm_type, depth, epsilon):
605626 """Apply Normalization."""
606627 if norm_type == "layer" :
607628 return layer_norm (x , filters = depth , epsilon = epsilon )
629+ if norm_type == "group" :
630+ return group_norm (x , filters = depth , epsilon = epsilon )
608631 if norm_type == "batch" :
609632 return tf .layers .batch_normalization (x , epsilon = epsilon )
610633 if norm_type == "noam" :
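
A minimal usage sketch for the new path, not part of the commit itself: it assumes TF 1.x graph/session semantics and that group_norm lands in tensor2tensor.layers.common_layers (the module this diff edits); the input shape, scope name, and num_groups value are arbitrary choices for illustration.

import numpy as np
import tensorflow as tf

from tensor2tensor.layers import common_layers

# NHWC input: batch=2, height=16, width=16, channels=32 (32 % 8 == 0, so
# the assert on filters % num_groups passes).
x = tf.constant(np.random.randn(2, 16, 16, 32), dtype=tf.float32)
with tf.variable_scope("group_norm_example"):
  # Statistics are taken over height, width and the channels inside each of
  # the 8 groups (moments over axes [1, 2, 4] after the reshape), so the
  # output keeps the original NHWC shape.
  y = common_layers.group_norm(x, num_groups=8, epsilon=1e-5)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(y).shape)  # (2, 16, 16, 32)

The reshape to [..., num_groups, filters // num_groups] followed by moments over axes [1, 2, 4] is what distinguishes this from layer_norm, which averages only over the last dimension.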