This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f703629

T2T Team authored and Ryan Sepassi committed
Change summary generation to work better in multi-model case.
PiperOrigin-RevId: 162429483
1 parent d3502cb · commit f703629

10 files changed: +27 additions, -59 deletions
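
For context: across the diffs below, the per-call `summaries` boolean that each model computed from `hparams.problems` is removed, and the layer helpers in common_attention.py / common_layers.py now emit summaries whenever their variable scope is not being reused; trainer_utils.py (the last file below) additionally strips summaries created inside conditionals. A minimal TF 1.x sketch of that reuse gating, using a hypothetical helper name (not part of the commit):

import tensorflow as tf

def maybe_histogram_summary(name, tensor):
  # Hypothetical helper mirroring the gating now used inside
  # dot_product_attention, simple_attention and conv_hidden_relu: write the
  # summary only on the first (non-reused) construction of the current
  # variable scope, so reused scopes (e.g. additional towers) add nothing.
  if not tf.get_variable_scope().reuse:
    tf.summary.histogram(name, tensor)
  return tensor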

tensor2tensor/models/attention_lm.py

Lines changed: 0 additions & 3 deletions
@@ -101,8 +101,6 @@ def attention_lm_decoder(decoder_input,
     y: a Tensors
   """
   x = decoder_input
-  # Summaries don't work in multi-problem setting yet.
-  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
@@ -117,7 +115,6 @@ def attention_lm_decoder(decoder_input,
                 hparams.hidden_size,
                 hparams.num_heads,
                 hparams.attention_dropout,
-                summaries=summaries,
                 name="decoder_self_attention"))
         x = residual_fn(x,
                         common_layers.conv_hidden_relu(

tensor2tensor/models/attention_lm_moe.py

Lines changed: 0 additions & 1 deletion
@@ -69,7 +69,6 @@ def residual_fn(x, y):
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
-               summaries=True,
                name="decoder_self_attention")
         x = dp(residual_fn, x, y)
       with tf.variable_scope("ffn"):

tensor2tensor/models/common_attention.py

Lines changed: 5 additions & 12 deletions
@@ -312,7 +312,6 @@ def dot_product_attention(q,
                           v,
                           bias,
                           dropout_rate=0.0,
-                          summaries=False,
                           image_shapes=None,
                           name=None):
   """dot-product attention.
@@ -323,7 +322,6 @@ def dot_product_attention(q,
     v: a Tensor with shape [batch, heads, length_kv, depth_v]
     bias: bias Tensor (see attention_bias())
     dropout_rate: a floating point number
-    summaries: a boolean
     image_shapes: optional tuple of integer scalars.
       see comments for attention_image_summary()
     name: an optional string
@@ -340,13 +338,13 @@ def dot_product_attention(q,
     weights = tf.nn.softmax(logits, name="attention_weights")
     # dropping out the attention links for each of the heads
     weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
-    if summaries and not tf.get_variable_scope().reuse:
+    if not tf.get_variable_scope().reuse:
       attention_image_summary(weights, image_shapes)
     return tf.matmul(weights, v)


 def masked_local_attention_1d(
-    q, k, v, block_length=128, summaries=True, name=None):
+    q, k, v, block_length=128, name=None):
   """Attention to the source position and a neigborhood to the left of it.

   The sequence is divided into blocks of length block_size.
@@ -362,7 +360,6 @@ def masked_local_attention_1d(
     k: a Tensor with shape [batch, heads, length, depth_k]
     v: a Tensor with shape [batch, heads, length, depth_v]
     block_length: an integer
-    summaries: a boolean
     name: an optional string

   Returns:
@@ -394,7 +391,7 @@ def masked_local_attention_1d(
     first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
     first_output = dot_product_attention(
         first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
-        summaries=summaries, name="fist_block")
+        name="fist_block")

     # compute attention for all subsequent query blocks.
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
@@ -442,7 +439,6 @@ def multihead_attention(query_antecedent,
                         output_depth,
                         num_heads,
                         dropout_rate,
-                        summaries=False,
                         image_shapes=None,
                         attention_type="dot_product",
                         block_length=128,
@@ -458,7 +454,6 @@ def multihead_attention(query_antecedent,
     output_depth: an integer
     num_heads: an integer dividing total_key_depth and total_value_depth
     dropout_rate: a floating point number
-    summaries: a boolean
     image_shapes: optional tuple of integer scalars.
       see comments for attention_image_summary()
     attention_type: a string, either "dot_product" or "local_mask_right"
@@ -509,12 +504,10 @@ def multihead_attention(query_antecedent,
     q *= key_depth_per_head**-0.5
     if attention_type == "dot_product":
       x = dot_product_attention(
-          q, k, v, bias, dropout_rate, summaries, image_shapes)
+          q, k, v, bias, dropout_rate, image_shapes)
     else:
       assert attention_type == "local_mask_right"
-      x = masked_local_attention_1d(q, k, v,
-                                    block_length=block_length,
-                                    summaries=summaries)
+      x = masked_local_attention_1d(q, k, v, block_length=block_length)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x

tensor2tensor/models/common_layers.py

Lines changed: 9 additions & 17 deletions
@@ -777,15 +777,15 @@ def moe_layer(data_parallelism,
   xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n)
   # Call the MoE
   moe_out_2d, importance, load, _, _ = moe.Eval(
-      dp.devices, xs_2d, train, identifiers=None, summaries=True)
+      dp.devices, xs_2d, train, identifiers=None)
   # Reshape the output to the original shape.
   moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs))
   # These losses encourage equal load on the different experts.
   loss = loss_coef * (eu.CVSquared(importance) + eu.CVSquared(load))
   return moe_out, loss


-def simple_attention(target, source, bias=None, summaries=True):
+def simple_attention(target, source, bias=None):
   """A simple attention function.

   Args:
@@ -795,7 +795,6 @@ def simple_attention(target, source, bias=None, summaries=True):
       `[batch, source_timesteps_1, source_timesteps_2, depth]`
     bias: an optional `Tensor` with shape `[batch, timesteps, 1, 1]` used
       to mask the attention to not attend to padding of input.
-    summaries: Boolean, whether to output summaries.

   Returns:
     a `Tensor` with same shape as `target`
@@ -814,7 +813,7 @@ def simple_attention(target, source, bias=None, summaries=True):
   if bias is not None:
     attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1)
   attention = tf.nn.softmax(attention)
-  if summaries and not tf.get_variable_scope().reuse:
+  if not tf.get_variable_scope().reuse:
     tf.summary.image("attention", tf.expand_dims(attention, 3), max_outputs=5)
   attended = tf.matmul(attention, source)
   return tf.reshape(attended, target_shape)
@@ -861,8 +860,7 @@ def multiscale_conv_sum(inputs, output_size, dilation_rates_and_kernel_sizes,
 def multiscale_conv_and_attention(x,
                                   padding,
                                   hparams,
-                                  source=None,
-                                  summaries=True):
+                                  source=None):
   """A common part of t2t layers.

   First, do a linear multiscale convolution
@@ -875,7 +873,6 @@ def multiscale_conv_and_attention(x,
     padding: a padding type
     hparams: hyperparameters for model
     source: optional source tensor for attention. (encoder output)
-    summaries: Boolean, whether to output summaries.

   Returns:
     a Tensor.
@@ -893,7 +890,7 @@ def multiscale_conv_and_attention(x,
   x = conv(x, hparams.hidden_size, (1, 1))
   x = noam_norm(x + conv_sum)
   if source is not None:
-    x = noam_norm(x + simple_attention(x, source, summaries=summaries))
+    x = noam_norm(x + simple_attention(x, source))
   return x


@@ -930,8 +927,7 @@ def conv_with_pools(inputs, output_size, kernel_size, pool_sizes, pooling_type,
 def conv_with_pools_and_attention(x,
                                   padding,
                                   hparams,
-                                  source=None,
-                                  summaries=True):
+                                  source=None):
   """A common part of t2t layers.

   First, do conv_with_pools
@@ -944,7 +940,6 @@ def conv_with_pools_and_attention(x,
     padding: a padding type
     hparams: hyperparameters for model
     source: optional source tensor for attention. (encoder output)
-    summaries: Boolean, whether to output summaries.

   Returns:
     a Tensor.
@@ -959,7 +954,7 @@ def conv_with_pools_and_attention(x,
     conv_sum += x
   x = noam_norm(conv_sum)
   if source is not None:
-    x = noam_norm(x + simple_attention(x, source, summaries=summaries))
+    x = noam_norm(x + simple_attention(x, source))
   return x


@@ -1057,7 +1052,6 @@ def attention_1d_v0(source,
                     transform_source=True,
                     transform_target=True,
                     transform_output=True,
-                    summaries=True,
                     name=None):
   """multi-headed attention.

@@ -1075,7 +1069,6 @@ def attention_1d_v0(source,
     transform_source: a boolean
     transform_target: a boolean
     transform_output: a boolean
-    summaries: a boolean
     name: an optional string

   Returns:
@@ -1116,7 +1109,7 @@ def _maybe_transform(t, size, should_transform, name):
       mask = (1.0 - mask) * -1e9
       attention += mask
     attention = tf.nn.softmax(attention)
-    if summaries and not tf.get_variable_scope().reuse:
+    if not tf.get_variable_scope().reuse:
       # Compute a color image summary.
       image = tf.reshape(attention,
                          [batch, num_heads, target_length, source_length])
@@ -1162,7 +1155,6 @@ def conv_hidden_relu(inputs,
                      output_size,
                      kernel_size=(1, 1),
                      second_kernel_size=(1, 1),
-                     summaries=True,
                      dropout=0.0,
                      **kwargs):
   """Hidden layer with RELU activation followed by linear projection."""
@@ -1183,7 +1175,7 @@ def conv_hidden_relu(inputs,
         **kwargs)
     if dropout != 0.0:
       h = tf.nn.dropout(h, 1.0 - dropout)
-    if summaries and not tf.get_variable_scope().reuse:
+    if not tf.get_variable_scope().reuse:
       tf.summary.histogram("hidden_density_logit",
                            relu_density_logit(
                                h, list(range(inputs.shape.ndims - 1))))

tensor2tensor/models/long_answer.py

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ def residual_fn(x, y):
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
-               summaries=True,
                attention_type="local_mask_right",
                block_length=hparams.block_length,
                name="decoder_self_attention")

tensor2tensor/models/multimodel.py

Lines changed: 1 addition & 4 deletions
@@ -138,7 +138,6 @@ def flatten(inputs):
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
-             summaries=False,
              name="decoder_self_attention")
       z = dp(common_attention.multihead_attention,
              y,
@@ -149,7 +148,6 @@ def flatten(inputs):
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
-             summaries=False,
              name="encdec_attention")
       x = dp(residual_fn3, x, y, z, hparams)
       with tf.variable_scope("ffn"):
@@ -164,8 +162,7 @@ def flatten(inputs):
               x,
               hparams.filter_size,
               hparams.hidden_size,
-              dropout=hparams.dropout,
-              summaries=False)
+              dropout=hparams.dropout)
       x = dp(residual_fn2, x, y, hparams)

       x = dp(tf.expand_dims, x, 2)

tensor2tensor/models/slicenet.py

Lines changed: 3 additions & 5 deletions
@@ -64,8 +64,7 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
         hparams.hidden_size,
         hparams.num_heads,
         hparams.attention_dropout,
-        name="self_attention",
-        summaries=False)
+        name="self_attention")
     qv = common_attention.multihead_attention(
         qv,
         inputs_encoded,
@@ -75,12 +74,11 @@ def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
         hparams.hidden_size,
         hparams.num_heads,
         hparams.attention_dropout,
-        name="encdec_attention",
-        summaries=False)
+        name="encdec_attention")
     return tf.expand_dims(qv, 2)
   elif hparams.attention_type == "simple":
     targets_with_attention = common_layers.simple_attention(
-        targets_timed, inputs_encoded, bias=bias, summaries=False)
+        targets_timed, inputs_encoded, bias=bias)
     return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")

tensor2tensor/models/transformer.py

Lines changed: 2 additions & 13 deletions
@@ -143,8 +143,6 @@ def transformer_encoder(encoder_input,
     y: a Tensors
   """
   x = encoder_input
-  # Summaries don't work in multi-problem setting yet.
-  summaries = len(hparams.problems) < 2
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
@@ -159,7 +157,6 @@ def transformer_encoder(encoder_input,
                 hparams.hidden_size,
                 hparams.num_heads,
                 hparams.attention_dropout,
-                summaries=summaries,
                 name="encoder_self_attention"))
         x = residual_fn(x, transformer_ffn_layer(x, hparams))
   return x
@@ -189,8 +186,6 @@ def transformer_decoder(decoder_input,
     y: a Tensors
   """
   x = decoder_input
-  # Summaries don't work in multi-problem setting yet.
-  summaries = len(hparams.problems) < 2
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
@@ -205,7 +200,6 @@ def transformer_decoder(decoder_input,
                 hparams.hidden_size,
                 hparams.num_heads,
                 hparams.attention_dropout,
-                summaries=summaries,
                 name="decoder_self_attention"))
         x = residual_fn(
             x,
@@ -218,7 +212,6 @@ def transformer_decoder(decoder_input,
                 hparams.hidden_size,
                 hparams.num_heads,
                 hparams.attention_dropout,
-                summaries=summaries,
                 name="encdec_attention"))
         x = residual_fn(x, transformer_ffn_layer(x, hparams))
   return x
@@ -234,15 +227,12 @@ def transformer_ffn_layer(x, hparams):
   Returns:
     a Tensor of shape [batch_size, length, hparams.hidden_size]
   """
-  # Summaries don't work in multi-problem setting yet.
-  summaries = len(hparams.problems) < 2
   if hparams.ffn_layer == "conv_hidden_relu":
     return common_layers.conv_hidden_relu(
         x,
         hparams.filter_size,
         hparams.hidden_size,
-        dropout=hparams.relu_dropout,
-        summaries=summaries)
+        dropout=hparams.relu_dropout)
   elif hparams.ffn_layer == "parameter_attention":
     return common_attention.parameter_attention(
         x,
@@ -260,8 +250,7 @@ def transformer_ffn_layer(x, hparams):
         kernel_size=(3, 1),
         second_kernel_size=(31, 1),
         padding="LEFT",
-        dropout=hparams.relu_dropout,
-        summaries=summaries)
+        dropout=hparams.relu_dropout)
   else:
     assert hparams.ffn_layer == "none"
     return x

tensor2tensor/models/transformer_alternative.py

Lines changed: 0 additions & 3 deletions
@@ -140,8 +140,6 @@ def alt_transformer_decoder(decoder_input,
   """Alternative decoder."""
   x = decoder_input

-  # Summaries don't work in multi-problem setting yet.
-  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
   with tf.variable_scope(name):
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
@@ -155,7 +153,6 @@ def alt_transformer_decoder(decoder_input,
                 hparams.hidden_size,
                 hparams.num_heads,
                 hparams.attention_dropout,
-                summaries=summaries,
                 name="encdec_attention")

       x_ = residual_fn(x_, composite_layer(x_, mask, hparams))

tensor2tensor/utils/trainer_utils.py

Lines changed: 7 additions & 0 deletions
@@ -550,6 +550,13 @@ def nth_model(n):
           optimizer=opt,
           colocate_gradients_with_ops=True)

+      # Remove summaries that will fail to run because they are in conditionals.
+      # TODO(cwhipkey): Test with this code removed, later in 2017.
+      summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
+      for i in range(len(summaries)-1, -1, -1):
+        if summaries[i].name.startswith("cond_"):
+          del summaries[i]
+
       tf.logging.info("Global model_fn finished.")
       return run_info, total_loss, train_op
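
A self-contained TF 1.x sketch (illustration only, not from the repo) of what this clean-up accomplishes: summary ops built inside tf.cond branches end up under a "cond*" name scope and, as the commit's comment notes, will fail to run when merged and evaluated unconditionally, so they are dropped from the SUMMARIES collection before tf.summary.merge_all() collects them. The commit filters on the "cond_" prefix; the toy graph below filters on the broader "cond" prefix so it also catches the first (unsuffixed) cond in the graph.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[])
use_branch = tf.placeholder(tf.bool, shape=[])

def branch_with_summary():
  # This summary op lives inside the cond branch; its name is prefixed with
  # the cond's name scope (e.g. "cond/...").
  tf.summary.scalar("inside_cond", x)
  return x * 2.0

y = tf.cond(use_branch, branch_with_summary, lambda: x)
tf.summary.scalar("outside_cond", y)

# Same loop as the commit, with the broader prefix for this toy graph.
summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
for i in range(len(summaries) - 1, -1, -1):
  if summaries[i].name.startswith("cond"):
    del summaries[i]

merged = tf.summary.merge_all()  # now only includes the "outside_cond" summary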
