
Commit bcf9a8b

Added example model

1 parent 08d6a63


tensor2tensor/models/transformer_alternative_.py renamed to tensor2tensor/models/transformer_alternative.py

Lines changed: 69 additions & 51 deletions
@@ -60,6 +60,12 @@ def model_fn_body(self, features):
         transformer_prepare_encoder(inputs, target_space, hparams) )
     (decoder_input, decoder_self_attention_bias) = transformer.\
         transformer_prepare_decoder(targets, hparams)
+
+    # We need masks of the form batch size x input sequences
+    # Biases seem to be of the form batch_size x 1 x input sequences x vec dim
+    # Squeeze out dim one, and get the first element of each vector
+    encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:, :, 0]
+    decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:, :, 0]
 
     def residual_fn(x, y):
       return common_layers.layer_norm(x + tf.nn.dropout(
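As a shape sanity check (not part of the commit), here is a small NumPy sketch of what the squeeze-and-slice above computes, assuming the layout the new comment describes: a bias of shape [batch_size, 1, input_length, vec_dim] with 0.0 at real positions and a large negative value at padded ones. All names and values below are illustrative.

import numpy as np

# Hypothetical bias: batch of 2, length 4, vec_dim 3; example 1 has its
# last two positions padded (marked with a large negative value).
batch, length, vec_dim = 2, 4, 3
bias = np.zeros((batch, 1, length, vec_dim), dtype=np.float32)
bias[1, 0, 2:, :] = -1e9

# Drop the singleton dim, then keep the first element of each vector,
# mirroring tf.squeeze(bias, [1])[:, :, 0] above.
mask = np.squeeze(bias, axis=1)[:, :, 0]
print(mask.shape)  # (2, 4), i.e. batch_size x input_length
print(mask[1])     # [0, 0, -1e9, -1e9] up to float formatting

Note that the slice keeps the bias values themselves (0.0 or -1e9) rather than converting them to 0/1 indicators.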
@@ -68,35 +74,60 @@ def residual_fn(x, y):
   encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
   decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
   encoder_output = alt_transformer_encoder(
-      encoder_input, residual_fn, encoder_attention_bias, hparams)
+      encoder_input, residual_fn, encoder_mask, hparams)
 
   decoder_output = alt_transformer_decoder(
-      decoder_input, encoder_output, residual_fn, decoder_self_attention_bias,
+      decoder_input, encoder_output, residual_fn, decoder_mask,
       encoder_attention_bias, hparams)
 
   decoder_output = tf.expand_dims(decoder_output, 2)
 
   return decoder_output
 
 
+
+def composite_layer(inputs, mask, hparams):
+  x = inputs
+
+  # Applies ravanbakhsh on top of each other
+  if hparams.composite_layer_type == "ravanbakhsh":
+    for layer in xrange(hparams.layers_per_layer):
+      with tf.variable_scope(".%d" % layer):
+        x = common_layers.ravanbakhsh_set_layer(
+            hparams.hidden_size,
+            x,
+            mask=mask,
+            dropout=0.0)
+
+  # Transforms elements to get a context, and then uses this in a final layer
+  elif hparams.composite_layer_type == "reembedding":
+    initial_elems = x
+    # Transform elements n times and then pool
+    for layer in xrange(hparams.layers_per_layer):
+      with tf.variable_scope(".%d" % layer):
+        x = common_layers.linear_set_layer(
+            hparams.hidden_size,
+            x,
+            dropout=0.0)
+    context = common_layers.global_pool_1d(x, mask=mask)
+
+    # Final layer
+    x = common_layers.linear_set_layer(
+        hparams.hidden_size,
+        x,
+        context=context,
+        dropout=0.0)
+
+  return x
+
+
+
 def alt_transformer_encoder(encoder_input,
                             residual_fn,
-                            encoder_attention_bias,
+                            mask,
                             hparams,
                             name="encoder"):
-  """
-  A stack of transformer layers.
 
-  Args:
-    encoder_input: a Tensor
-    residual_fn: a function from (layer_input, layer_output) -> combined_output
-
-    hparams: hyperparameters for model
-    name: a string
-
-  Returns:
-    y: a Tensors
-  """
   x = encoder_input
 
   # Summaries don't work in multi-problem setting yet.
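To make the "reembedding" branch of composite_layer concrete, here is a toy, dense-only NumPy sketch of the same pattern: transform each set element, pool a context vector over the unmasked positions, then apply a final per-element layer that also sees that context. The weights, the ReLU, and the mean pooling are illustrative assumptions, not the tensor2tensor implementation (which uses common_layers.linear_set_layer and global_pool_1d).

import numpy as np

def toy_reembedding(x, mask, w_elem, w_out, w_ctx):
  # x: [batch, length, depth]; mask: [batch, length], 1.0 for real positions.
  h = np.maximum(x @ w_elem, 0.0)                        # per-element transform
  m = mask[:, :, None]
  context = (h * m).sum(1) / np.maximum(m.sum(1), 1.0)   # masked mean pool -> [batch, hidden]
  # final per-element layer conditioned on the pooled context
  return np.maximum(h @ w_out + (context @ w_ctx)[:, None, :], 0.0)

rng = np.random.RandomState(0)
batch, length, depth, hidden = 2, 4, 8, 16
x = rng.randn(batch, length, depth)
mask = np.array([[1., 1., 1., 1.], [1., 1., 0., 0.]])
out = toy_reembedding(x, mask, rng.randn(depth, hidden),
                      rng.randn(hidden, hidden), rng.randn(hidden, hidden))
print(out.shape)  # (2, 4, 16)

The "ravanbakhsh" branch instead stacks what are presumably the permutation-equivariant set layers of Ravanbakhsh et al., each of which already takes the mask directly.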
@@ -105,36 +136,19 @@ def alt_transformer_encoder(encoder_input,
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
-        x = residual_fn(
-            x,
-            ravanbakhsh_set_layer(hparams.hidden_size, x, mask=encoder_attention_bias)
-        )
+        x = residual_fn(x, composite_layer(x, mask, hparams))
 
   return x
 
 
 def alt_transformer_decoder(decoder_input,
                             encoder_output,
                             residual_fn,
-                            decoder_self_attention_bias,
+                            mask,
                             encoder_decoder_attention_bias,
                             hparams,
                             name="decoder"):
-  """
-  A stack of transformer layers.
-
-  Args:
-    decoder_input: a Tensor
-    encoder_output: a Tensor
-    residual_fn: a function from (layer_input, layer_output) -> combined_output
-    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
-      (see common_attention.attention_bias())
-    hparams: hyperparameters for model
-    name: a string
-
-  Returns:
-    y: a Tensors
-  """
+
   x = decoder_input
 
   # Summaries don't work in multi-problem setting yet.
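With this change an encoder layer is simply the residual_fn wrapper applied to composite_layer. For reference, a toy version of that wrapper: residual add followed by layer normalization, with the dropout, scale, and offset of the real common_layers.layer_norm omitted. An illustration only, not the library code.

import numpy as np

def toy_residual(x, y, epsilon=1e-6):
  # residual add, then normalize over the last axis
  z = x + y
  mean = z.mean(-1, keepdims=True)
  var = ((z - mean) ** 2).mean(-1, keepdims=True)
  return (z - mean) / np.sqrt(var + epsilon)

# Each encoder layer then amounts to: x = toy_residual(x, composite_layer_output).
# The decoder additionally runs encoder-decoder attention first, as the next hunk shows.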
@@ -143,30 +157,34 @@ def alt_transformer_decoder(decoder_input,
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
 
-        x = residual_fn(
-            x,
-            ravanbakhsh_set_layer(hparams.hidden_size,
-                common_attention.multihead_attention(
-                    x,
-                    encoder_output,
-                    encoder_decoder_attention_bias,
-                    hparams.attention_key_channels or hparams.hidden_size,
-                    hparams.attention_value_channels or hparams.hidden_size,
-                    hparams.hidden_size,
-                    hparams.num_heads,
-                    hparams.attention_dropout,
-                    summaries=summaries,
-                    name="encdec_attention"),
-                mask=decoder_self_attention_bias)
-        )
+        x_ = common_attention.multihead_attention(
+            x,
+            encoder_output,
+            encoder_decoder_attention_bias,
+            hparams.attention_key_channels or hparams.hidden_size,
+            hparams.attention_value_channels or hparams.hidden_size,
+            hparams.hidden_size,
+            hparams.num_heads,
+            hparams.attention_dropout,
+            summaries=summaries,
+            name="encdec_attention")
+
+        x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
+        x = residual_fn(x, x_)
 
   return x
 
 
+
+
 @registry.register_hparams
 def transformer_alt():
   """Set of hyperparameters."""
   hparams = transformer.transformer_base()
+  hparams.batch_size = 64
   hparams.add_hparam("layers_per_layer", 4)
+  #hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding
+  hparams.add_hparam("composite_layer_type", "reembedding")
   return hparams
 