@@ -60,6 +60,12 @@ def model_fn_body(self, features):
         transformer_prepare_encoder(inputs, target_space, hparams))
     (decoder_input, decoder_self_attention_bias) = transformer.\
         transformer_prepare_decoder(targets, hparams)
+
+    # We need masks of shape [batch_size, input_length].
+    # The biases appear to be of shape [batch_size, 1, input_length, depth],
+    # so squeeze out dim 1 and take the first element of each vector.
+    encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:, :, 0]
+    decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:, :, 0]
 
     def residual_fn(x, y):
       return common_layers.layer_norm(x + tf.nn.dropout(
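Note on the bias-to-mask conversion above: the padding bias is 0.0 at real positions and a large negative value at padded ones, so the squeezed slice is passed on directly as the mask. A minimal sketch of the shape manipulation, under assumed shapes (the real tensors come from transformer_prepare_encoder/decoder):

```python
import tensorflow as tf

# Hypothetical padding bias: 0.0 at real positions, -1e9 at padded ones.
# Shape [batch=1, 1, length=3, depth=1] is an assumption for illustration.
bias = tf.constant([[[[0.0], [0.0], [-1e9]]]])

# Same reduction as in the diff: drop dim 1, keep the first channel.
mask = tf.squeeze(bias, [1])[:, :, 0]   # shape [1, 3]

with tf.Session() as sess:
  print(sess.run(mask))                 # approx. [[0. 0. -1e+09]]
```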
@@ -68,35 +74,60 @@ def residual_fn(x, y):
     encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
     decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
     encoder_output = alt_transformer_encoder(
-        encoder_input, residual_fn, encoder_attention_bias, hparams)
+        encoder_input, residual_fn, encoder_mask, hparams)
 
     decoder_output = alt_transformer_decoder(
-        decoder_input, encoder_output, residual_fn, decoder_self_attention_bias,
+        decoder_input, encoder_output, residual_fn, decoder_mask,
         encoder_attention_bias, hparams)
 
     decoder_output = tf.expand_dims(decoder_output, 2)
 
     return decoder_output
 
 
+
+def composite_layer(inputs, mask, hparams):
+  x = inputs
+
+  # Stack several ravanbakhsh_set_layers on top of each other.
+  if hparams.composite_layer_type == "ravanbakhsh":
+    for layer in xrange(hparams.layers_per_layer):
+      with tf.variable_scope(".%d" % layer):
+        x = common_layers.ravanbakhsh_set_layer(
+            hparams.hidden_size,
+            x,
+            mask=mask,
+            dropout=0.0)
+
+  # Transform the elements to get a context, then use it in a final layer.
+  elif hparams.composite_layer_type == "reembedding":
+    initial_elems = x
+    # Transform the elements n times, then pool.
+    for layer in xrange(hparams.layers_per_layer):
+      with tf.variable_scope(".%d" % layer):
+        x = common_layers.linear_set_layer(
+            hparams.hidden_size,
+            x,
+            dropout=0.0)
+    context = common_layers.global_pool_1d(x, mask=mask)
+
+    # Final layer.
+    x = common_layers.linear_set_layer(
+        hparams.hidden_size,
+        x,
+        context=context,
+        dropout=0.0)
+
+  return x
+
+
+
 def alt_transformer_encoder(encoder_input,
                             residual_fn,
-                            encoder_attention_bias,
+                            mask,
                             hparams,
                             name="encoder"):
-  """
-  A stack of transformer layers.
 
-  Args:
-    encoder_input: a Tensor
-    residual_fn: a function from (layer_input, layer_output) -> combined_output
-
-    hparams: hyperparameters for model
-    name: a string
-
-  Returns:
-    y: a Tensors
-  """
   x = encoder_input
 
   # Summaries don't work in multi-problem setting yet.
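The "reembedding" branch of composite_layer follows a transform / pool / re-embed pattern. A rough sketch of the same idea in plain TF, without the common_layers helpers (the dense layers and the masked-mean pooling are assumptions standing in for linear_set_layer and global_pool_1d, not the library's exact implementation):

```python
import tensorflow as tf

def reembedding_sketch(x, mask, hidden_size, layers_per_layer):
  """x: [batch, length, depth]; mask: [batch, length] with 1.0 = real token."""
  # Transform each element independently a few times.
  for i in range(layers_per_layer):
    x = tf.layers.dense(x, hidden_size, activation=tf.nn.relu,
                        name="transform_%d" % i)

  # Masked mean-pool over the length dimension to get one context vector per
  # example (stands in for common_layers.global_pool_1d).
  mask = tf.expand_dims(mask, 2)                      # [batch, length, 1]
  context = tf.reduce_sum(x * mask, axis=1) / (
      tf.reduce_sum(mask, axis=1) + 1e-9)             # [batch, hidden_size]

  # Re-embed each element conditioned on the pooled context.
  context = tf.tile(tf.expand_dims(context, 1),
                    tf.stack([1, tf.shape(x)[1], 1]))
  return tf.layers.dense(tf.concat([x, context], axis=2), hidden_size,
                         name="reembed")
```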
@@ -105,36 +136,19 @@ def alt_transformer_encoder(encoder_input,
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
-        x = residual_fn(
-            x,
-            ravanbakhsh_set_layer(hparams.hidden_size, x, mask=encoder_attention_bias)
-        )
+        x = residual_fn(x, composite_layer(x, mask, hparams))
 
   return x
 
 
 def alt_transformer_decoder(decoder_input,
                             encoder_output,
                             residual_fn,
-                            decoder_self_attention_bias,
+                            mask,
                             encoder_decoder_attention_bias,
                             hparams,
                             name="decoder"):
-  """
-  A stack of transformer layers.
-
-  Args:
-    decoder_input: a Tensor
-    encoder_output: a Tensor
-    residual_fn: a function from (layer_input, layer_output) -> combined_output
-    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
-      (see common_attention.attention_bias())
-    hparams: hyperparameters for model
-    name: a string
-
-  Returns:
-    y: a Tensors
-  """
+
   x = decoder_input
 
   # Summaries don't work in multi-problem setting yet.
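Each encoder layer is now simply a residual connection around composite_layer. A toy sketch of that stacking pattern, with tf.contrib.layers.layer_norm standing in for common_layers.layer_norm (an assumption; the real combiner is the residual_fn defined in model_fn_body):

```python
import tensorflow as tf

def residual_fn_sketch(x, y, dropout_rate):
  # Layer-norm of x plus a dropped-out y, mirroring residual_fn above.
  return tf.contrib.layers.layer_norm(x + tf.nn.dropout(y, 1.0 - dropout_rate))

def encoder_stack_sketch(x, mask, hparams, composite_layer):
  # One residual composite layer per hidden layer, each in its own scope.
  for layer in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer):
      x = residual_fn_sketch(x, composite_layer(x, mask, hparams),
                             hparams.residual_dropout)
  return x
```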
@@ -143,30 +157,34 @@ def alt_transformer_decoder(decoder_input,
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
 
-        x = residual_fn(
-            x,
-            ravanbakhsh_set_layer(hparams.hidden_size,
-                common_attention.multihead_attention(
-                    x,
-                    encoder_output,
-                    encoder_decoder_attention_bias,
-                    hparams.attention_key_channels or hparams.hidden_size,
-                    hparams.attention_value_channels or hparams.hidden_size,
-                    hparams.hidden_size,
-                    hparams.num_heads,
-                    hparams.attention_dropout,
-                    summaries=summaries,
-                    name="encdec_attention"),
-                mask=decoder_self_attention_bias)
-        )
+        x_ = common_attention.multihead_attention(
+            x,
+            encoder_output,
+            encoder_decoder_attention_bias,
+            hparams.attention_key_channels or hparams.hidden_size,
+            hparams.attention_value_channels or hparams.hidden_size,
+            hparams.hidden_size,
+            hparams.num_heads,
+            hparams.attention_dropout,
+            summaries=summaries,
+            name="encdec_attention")
+
+        x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
+        x = residual_fn(x, x_)
 
   return x
 
 
+
+
+
 @registry.register_hparams
 def transformer_alt():
   """Set of hyperparameters."""
   hparams = transformer.transformer_base()
+  hparams.batch_size = 64
   hparams.add_hparam("layers_per_layer", 4)
+  # hparams.add_hparam("composite_layer_type", "ravanbakhsh")  # ravanbakhsh or reembedding
+  hparams.add_hparam("composite_layer_type", "reembedding")
   return hparams
 
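If both composite layer types need to be compared, a second hparams set could be registered instead of toggling the commented-out line above. A hypothetical variant (the name transformer_alt_ravanbakhsh is invented for illustration and is not part of this change):

```python
@registry.register_hparams
def transformer_alt_ravanbakhsh():
  """Hypothetical variant of transformer_alt using the ravanbakhsh layer."""
  hparams = transformer_alt()
  hparams.composite_layer_type = "ravanbakhsh"
  return hparams
```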