 from __future__ import division
 from __future__ import print_function

-from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
-from tensor2tensor.models.research.shuffle_network import shuffle_layer
-from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
+import numpy as np
 from tensor2tensor.layers.common_layers import gelu
+from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
+from tensor2tensor.models.research.shuffle_network import shuffle_layer
+from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
 from tensor2tensor.utils import registry
-
-import numpy as np
 import tensorflow.compat.v1 as tf


@@ -46,29 +45,34 @@ def __init__(self, axis=1, epsilon=1e-10, **kwargs):
     Args:
       axis: Tuple or number of axis for calculating mean and variance
       epsilon: Small epsilon to avoid division by zero
+      **kwargs: keyword args passed to super.
     """
     self.axis = axis
     self.epsilon = epsilon
     self.bias = None
     super(LayerNormalization, self).__init__(**kwargs)

   def build(self, input_shape):
-    """ Initialize bias weights for layer normalization.
+    """Initialize bias weights for layer normalization.
+
     Args:
       input_shape: shape of input tensor
     """
     num_units = input_shape.as_list()[-1]
-    self.bias = self.add_weight("bias", [1, 1, num_units],
-                                initializer=tf.zeros_initializer)
+    self.bias = self.add_weight(
+        "bias", [1, 1, num_units], initializer=tf.zeros_initializer)
     super(LayerNormalization, self).build(input_shape)

   def call(self, inputs, **kwargs):
-    """ Apply Layer Normalization without output bias and gain.
+    """Apply Layer Normalization without output bias and gain.

     Args:
-      inputs: tensor to be normalized. Axis should be smaller than input
-        tensor dimensions.
+      inputs: tensor to be normalized. Axis should be smaller than input tensor
+        dimensions.
       **kwargs: more arguments (unused)
+
+    Returns:
+      tensor output.
     """
     inputs -= tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
     inputs += self.bias
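The call above centers the activations and adds the learned bias; the variance rescaling that completes the normalization sits just past this hunk. A minimal NumPy sketch of the full operation, assuming that trailing variance step (layer_norm_no_gain is a hypothetical name):

    import numpy as np

    def layer_norm_no_gain(x, bias, axis=1, epsilon=1e-10):
      # Center along the normalization axis, add the learned bias, then
      # rescale by the RMS of the result -- no output gain or scale.
      x = x - x.mean(axis=axis, keepdims=True)
      x = x + bias
      variance = np.square(x).mean(axis=axis, keepdims=True)
      return x / np.sqrt(variance + epsilon)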
@@ -81,6 +85,9 @@ def inv_sigmoid(y):

   Args:
     y: float in range 0 to 1
+
+  Returns:
+    the inverse sigmoid.
   """
   return np.log(y / (1 - y))

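inv_sigmoid is the logit function, used below so that sigmoid(residual_scale) starts out at the chosen residual weight. A quick numeric check of the round trip:

    y = 0.9
    x = np.log(y / (1 - y))                           # inv_sigmoid(0.9) ~= 2.1972
    assert abs(1.0 / (1.0 + np.exp(-x)) - y) < 1e-12  # sigmoid(x) recovers 0.9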
@@ -107,7 +114,7 @@ def __init__(self, prefix, dropout, mode, **kwargs):
     self.residual_scale = None

     residual_weight = 0.9
-    self.candidate_weight = np.sqrt(1 - residual_weight ** 2) * 0.25
+    self.candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25
     self.init_value = inv_sigmoid(residual_weight)

   def build(self, input_shape):
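Plugging in the constants: sqrt(1 - 0.9**2) = sqrt(0.19) ≈ 0.4359, so candidate_weight ≈ 0.1090 and init_value = inv_sigmoid(0.9) ≈ 2.1972. The sqrt(1 - w**2) factor is the variance-preserving counterweight to a residual gate of strength w (0.81 + 0.19 = 1); the extra 0.25 further damps the candidate branch:

    residual_weight = 0.9
    candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25    # ~= 0.1090
    init_value = np.log(residual_weight / (1 - residual_weight)) # ~= 2.1972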
@@ -119,33 +126,35 @@ def build(self, input_shape):
     in_units = input_shape[-1]
     middle_units = in_units * 4
     out_units = in_units * 2
-    init = tf.variance_scaling_initializer(scale=1.0, mode="fan_avg",
-                                           distribution="uniform")
+    init = tf.variance_scaling_initializer(
+        scale=1.0, mode="fan_avg", distribution="uniform")

-    self.first_linear = tf.keras.layers.Dense(middle_units,
-                                              use_bias=False,
-                                              kernel_initializer=init,
-                                              name=self.prefix + "/cand1")
+    self.first_linear = tf.keras.layers.Dense(
+        middle_units,
+        use_bias=False,
+        kernel_initializer=init,
+        name=self.prefix + "/cand1")

-    self.second_linear = tf.keras.layers.Dense(out_units,
-                                               kernel_initializer=init,
-                                               name=self.prefix + "/cand2")
+    self.second_linear = tf.keras.layers.Dense(
+        out_units, kernel_initializer=init, name=self.prefix + "/cand2")
     self.layer_norm = LayerNormalization()

     init = tf.constant_initializer(self.init_value)
-    self.residual_scale = self.add_weight(self.prefix + "/residual",
-                                          [out_units], initializer=init)
+    self.residual_scale = self.add_weight(
+        self.prefix + "/residual", [out_units], initializer=init)
     super(RSU, self).build(input_shape)

   def call(self, inputs, **kwargs):
     """Apply Residual Switch Layer to inputs.

     Args:
-      inputs: Input tensor
+      inputs: Input tensor.
+      **kwargs: unused kwargs.

     Returns:
       tf.Tensor: New candidate value
     """
+    del kwargs
     input_shape = tf.shape(inputs)
     batch_size = input_shape[0]
     length = input_shape[1]
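The remainder of call falls outside this hunk; a sketch of how it plausibly proceeds, assuming the usual residual switch unit recipe (pair adjacent positions, two dense layers with a gelu between them, layer norm, sigmoid-gated residual mix), with names mirroring the weights built above:

    # Hypothetical continuation, not the verbatim source.
    pairs = tf.reshape(inputs, [batch_size, length // 2, -1])  # join unit pairs
    hidden = gelu(self.first_linear(pairs))                    # -> 4 * in_units
    candidate = self.layer_norm(self.second_linear(hidden))    # -> 2 * in_units
    residual = tf.sigmoid(self.residual_scale) * pairs
    outputs = tf.reshape(residual + candidate * self.candidate_weight,
                         input_shape)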
@@ -201,7 +210,7 @@ def residual_shuffle_network(inputs, hparams):


 def reverse_part(inputs, hparams, n_bits):
-  """ Reverse part of Beneš block.
+  """Reverse part of Benes block.

   Repeatably applies interleaved Residual Switch layer and Reverse Shuffle
   Layer. One set of weights used for all Switch layers.
@@ -222,24 +231,23 @@ def reverse_step(state, _):
     return reverse_shuffle_layer(new_state)

   reverse_outputs = tf.scan(
-    reverse_step,
-    tf.range(n_bits, n_bits * 2),
-    initializer=inputs,
-    parallel_iterations=1,
-    swap_memory=True)
+      reverse_step,
+      tf.range(n_bits, n_bits * 2),
+      initializer=inputs,
+      parallel_iterations=1,
+      swap_memory=True)

   return reverse_outputs[-1, :, :, :]


 def forward_part(block_out, hparams, n_bits):
-  """ Forward part of Beneš block.
+  """Forward part of Benes block.

   Repeatably applies interleaved Residual Switch layer and Shuffle
   Layer. One set of weights used for all Switch layers.

   Args:
-    inputs: inputs for forward part. Should be inputs from previous layers
-      or Beneš block.
+    block_out: TODO(authors) document.
     hparams: params of the network.
     n_bits: count of repeated layer applications.

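Both scans drive one shared-weight switch layer: tf.scan applies the step function once per element of the range, threading the state through, and indexing with [-1] keeps only the final state. A toy analogue:

    xs = tf.scan(lambda acc, _: acc * 2.0, tf.range(0, 3),
                 initializer=tf.ones([1]))
    # xs stacks the state after each step: [[2.], [4.], [8.]];
    # xs[-1] mirrors reverse_outputs[-1, :, :, :] above.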
@@ -254,11 +262,11 @@ def forward_step(state, _):
     return shuffle_layer(new_state)

   forward_outputs = tf.scan(
-    forward_step,
-    tf.range(0, n_bits),
-    initializer=block_out,
-    parallel_iterations=1,
-    swap_memory=True)
+      forward_step,
+      tf.range(0, n_bits),
+      initializer=block_out,
+      parallel_iterations=1,
+      swap_memory=True)

   return forward_outputs[-1, :, :, :]

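reverse_part and forward_part make up the two halves of a Benes block, which residual_shuffle_network composes. A hypothetical wiring to show the data flow (the actual order and stacking live in residual_shuffle_network):

    def benes_block(state, hparams, n_bits):
      # n_bits reverse-shuffled switch layers, then n_bits shuffled ones.
      state = reverse_part(state, hparams, n_bits)
      return forward_part(state, hparams, n_bits)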
@@ -272,6 +280,9 @@ def body(self, features):

     Args:
       features: dictionary of inputs and targets
+
+    Returns:
+      the network output.
     """

     inputs = tf.squeeze(features["inputs"], axis=2)
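T2T feeds body 4-D feature tensors, so the squeeze drops the dummy spatial axis. A shape check, assuming the usual [batch, length, 1, depth] layout:

    x = tf.zeros([8, 16, 1, 384])  # hypothetical batch from the input pipeline
    y = tf.squeeze(x, axis=2)      # -> shape [8, 16, 384]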