
Commit e1ecf01

nshazeer authored and Ryan Sepassi committed
Create a "long_answer" model for attacking the wikipedia title->article dataset.
PiperOrigin-RevId: 162264770
1 parent 5db92b5 commit e1ecf01

6 files changed, +389 −9 lines


tensor2tensor/data_generators/problem_hparams.py

Lines changed: 3 additions & 4 deletions

@@ -351,10 +351,9 @@ def wiki_32k(model_hparams):
   p = default_problem_hparams()
   encoder = text_encoder.SubwordTextEncoder(
       os.path.join(model_hparams.data_dir, "wiki_32k.subword_text_encoder"))
-  p.input_modality = {
-      "inputs": (registry.Modalities.SYMBOL, encoder.vocab_size)
-  }
-  p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
+  modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size)
+  p.input_modality = {"inputs": modality_spec}
+  p.target_modality = modality_spec
   p.vocabulary = {
       "inputs": encoder,
       "targets": encoder

tensor2tensor/models/common_attention.py

Lines changed: 103 additions & 4 deletions

@@ -280,13 +280,13 @@ def attention_image_summary(attn, image_shapes=None):
         (query_rows, query_cols, query_channels,
          memory_rows, memory_cols, memory_channels).
   """
-  num_heads = attn.get_shape().as_list()[1]
+  num_heads = tf.shape(attn)[1]
   # [batch, query_length, memory_length, num_heads]
   image = tf.transpose(attn, [0, 2, 3, 1])
   image = tf.pow(image, 0.2)  # for high-dynamic-range
   # Each head will correspond to one of RGB.
   # pad the heads to be a multiple of 3
-  image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]])
+  image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, tf.mod(-num_heads, 3)]])
   image = split_last_dimension(image, 3)
   image = tf.reduce_max(image, 4)
   if image_shapes is not None:
@@ -345,6 +345,95 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


+def masked_local_attention_1d(
+    q, k, v, block_length=128, summaries=True, name=None):
+  """Attention to the source position and a neighborhood to the left of it.
+
+  The sequence is divided into blocks of length block_length.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.
+
+  The right side is always masked: a target position cannot see source
+  positions greater than itself.
+
+  Args:
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
+    block_length: an integer
+    summaries: a boolean
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, length, depth_v]
+  """
+  with tf.variable_scope(name, default_name="local_attention_1d",
+                         values=[q, k, v]):
+    v_shape = v.get_shape()
+    batch = tf.shape(q)[0]
+    heads = tf.shape(q)[1]
+    length = tf.shape(q)[2]
+    # If (length < 2 * block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length * 2),
+                            length, block_length)
+    depth_k = tf.shape(q)[3]
+    depth_v = tf.shape(v)[3]
+    original_length = length
+    padding_size = tf.mod(-length, block_length)
+    length += padding_size
+    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+    q = tf.pad(q, padding)
+    k = tf.pad(k, padding)
+    v = tf.pad(v, padding)
+    num_blocks = tf.div(length, block_length)
+
+    # compute attention for the first query block.
+    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_output = dot_product_attention(
+        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
+        summaries=summaries, name="first_block")
+
+    # compute attention for all subsequent query blocks.
+    q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
+    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+
+    def local(x):
+      """Create a local version of the keys or values."""
+      prev_block = tf.slice(
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+      cur_block = tf.slice(
+          x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+      return tf.concat([prev_block, cur_block], 3)
+    local_k = local(k)
+    local_v = local(v)
+    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+
+    local_length = tf.shape(local_k)[3]
+
+    # [batch, heads, num_blocks - 1, block_length, local_length]
+    attention = tf.matmul(tail_q, local_k, transpose_b=True)
+
+    # make sure source_pos <= target_pos
+    good_part = tf.matrix_band_part(
+        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    mask = (1.0 - good_part) * -1e9
+    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    attention = tf.nn.softmax(attention)
+    # TODO(noam): figure out how to show a summary for the remaining blocks.
+    # The naive way currently causes errors due to empty tensors.
+    # output: [batch, heads, num_blocks-1, block_length, depth_v]
+    output = tf.matmul(attention, local_v)
+    output = tf.reshape(output, [batch, heads, -1, depth_v])
+    output = tf.concat([first_output, output], axis=2)
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    output.set_shape(v_shape)
+    return output
+
+
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
@@ -355,6 +444,8 @@ def multihead_attention(query_antecedent,
                         dropout_rate,
                         summaries=False,
                         image_shapes=None,
+                        attention_type="dot_product",
+                        block_length=128,
                         name=None):
   """Multihead scaled-dot-product attention with input/output transformations.

@@ -370,6 +461,8 @@ def multihead_attention(query_antecedent,
     summaries: a boolean
     image_shapes: optional tuple of integer scalars.
       see comments for attention_image_summary()
+    attention_type: a string, either "dot_product" or "local_mask_right"
+    block_length: an integer - relevant for "local_mask_right"
     name: an optional string

   Returns:
@@ -414,8 +507,14 @@ def multihead_attention(query_antecedent,
     v = split_heads(v, num_heads)
     key_depth_per_head = total_key_depth // num_heads
     q *= key_depth_per_head**-0.5
-    x = dot_product_attention(
-        q, k, v, bias, dropout_rate, summaries, image_shapes)
+    if attention_type == "dot_product":
+      x = dot_product_attention(
+          q, k, v, bias, dropout_rate, summaries, image_shapes)
+    else:
+      assert attention_type == "local_mask_right"
+      x = masked_local_attention_1d(q, k, v,
+                                    block_length=block_length,
+                                    summaries=summaries)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
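
For context, a minimal usage sketch (not part of this commit) of the new code path: passing attention_type="local_mask_right" makes multihead_attention call masked_local_attention_1d instead of dot_product_attention. The sketch assumes the remainder of the multihead_attention signature in this file (total_key_depth, total_value_depth, output_depth and num_heads sit between bias and dropout_rate), that memory_antecedent=None selects self-attention as elsewhere in this file, and TF 1.x APIs; shapes and values are illustrative only.

import tensorflow as tf
from tensor2tensor.models import common_attention

# [batch, length, hidden_size]; length is much larger than block_length,
# so the local attention splits it into several blocks.
x = tf.random_normal([2, 4096, 512])
y = common_attention.multihead_attention(
    query_antecedent=x,
    memory_antecedent=None,   # assumed: None means self-attention
    bias=None,                # the local attention applies its own causal mask
    total_key_depth=512,
    total_value_depth=512,
    output_depth=512,
    num_heads=8,
    dropout_rate=0.0,
    attention_type="local_mask_right",
    block_length=128)
# y: [2, 4096, 512]; each query attends only to positions <= itself
# within its own block and the previous block.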

tensor2tensor/models/common_hparams.py

Lines changed: 3 additions & 0 deletions

@@ -72,6 +72,9 @@ def basic_params1():
       # setting the max length in a minibatch. 0 means default behavior,
       # max_length = hparams.batch_size * length_multiplier
       max_length=0,
+      # If set to True, drop sequences longer than max_length during eval.
+      # This affects the validity of the evaluation metrics.
+      eval_drop_long_sequences=int(False),
       # in SymbolModality, share the output embeddings and the softmax
       # variables.
       # You can also share the input embeddings with the output embeddings
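
A short, hypothetical sketch of switching the new flag on; it assumes basic_params1 returns a standard HParams object whose fields can be overridden by assignment, as is done elsewhere in tensor2tensor.

from tensor2tensor.models import common_hparams

hparams = common_hparams.basic_params1()
hparams.max_length = 256                      # cap on sequence length in a minibatch
hparams.eval_drop_long_sequences = int(True)  # drop over-length sequences during eval
# Caveat from the diff comment: dropping sequences changes which examples
# the eval metrics are computed over.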
