@@ -223,34 +223,35 @@ def evolved_transformer_encoder(encoder_input,
         hidden_state = common_layers.layer_postprocess(
             residual_state, hidden_state, hparams)

-        with tf.variable_scope("self_attention"):
-          residual_state = hidden_state
-          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+        if hparams.get("et_encoder_self_attention", True):
+          with tf.variable_scope("self_attention"):
+            residual_state = hidden_state
+            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

-          hidden_state = common_attention.multihead_attention(
-              hidden_state,
-              None,
-              encoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout,
-              attention_type=hparams.self_attention_type,
-              max_relative_position=hparams.max_relative_position,
-              heads_share_relative_embedding=(
-                  hparams.heads_share_relative_embedding),
-              add_relative_to_values=hparams.add_relative_to_values,
-              save_weights_to=save_weights_to,
-              make_image_summary=make_image_summary,
-              dropout_broadcast_dims=attention_dropout_broadcast_dims,
-              max_length=hparams.get("max_length"),
-              vars_3d=hparams.get("attention_variables_3d"),
-              activation_dtype=hparams.get("activation_dtype", "float32"),
-              weight_dtype=hparams.get("weight_dtype", "float32"))
+            hidden_state = common_attention.multihead_attention(
+                hidden_state,
+                None,
+                encoder_self_attention_bias,
+                hparams.attention_key_channels or hparams.hidden_size,
+                hparams.attention_value_channels or hparams.hidden_size,
+                hparams.hidden_size,
+                hparams.num_heads,
+                hparams.attention_dropout,
+                attention_type=hparams.self_attention_type,
+                max_relative_position=hparams.max_relative_position,
+                heads_share_relative_embedding=(
+                    hparams.heads_share_relative_embedding),
+                add_relative_to_values=hparams.add_relative_to_values,
+                save_weights_to=save_weights_to,
+                make_image_summary=make_image_summary,
+                dropout_broadcast_dims=attention_dropout_broadcast_dims,
+                max_length=hparams.get("max_length"),
+                vars_3d=hparams.get("attention_variables_3d"),
+                activation_dtype=hparams.get("activation_dtype", "float32"),
+                weight_dtype=hparams.get("weight_dtype", "float32"))

-          hidden_state = common_layers.layer_postprocess(
-              residual_state, hidden_state, hparams)
+            hidden_state = common_layers.layer_postprocess(
+                residual_state, hidden_state, hparams)

         with tf.variable_scope("dense_layers"):
           residual_state = hidden_state
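The change above gates the encoder's self-attention branch on a new et_encoder_self_attention hparam, defaulting to True so existing configurations are unchanged. A minimal sketch of how the flag could be toggled from an hparams set, assuming it is registered with add_hparam on top of the existing evolved_transformer_base hparams (the hparams-set name below is hypothetical, not part of this change):

from tensor2tensor.models import evolved_transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def evolved_transformer_no_enc_self_attention():
  """Evolved Transformer hparams with the encoder self-attention branch off."""
  hparams = evolved_transformer.evolved_transformer_base()
  # Hypothetical registration of the new flag: with this set,
  # hparams.get("et_encoder_self_attention", True) in the encoder
  # returns False and the self_attention scope is skipped.
  hparams.add_hparam("et_encoder_self_attention", False)
  return hparams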