@@ -280,13 +280,13 @@ def attention_image_summary(attn, image_shapes=None):
         (query_rows, query_cols, query_channels,
          memory_rows, memory_cols, memory_channels).
   """
-  num_heads = attn.get_shape().as_list()[1]
+  num_heads = tf.shape(attn)[1]
   # [batch, query_length, memory_length, num_heads]
   image = tf.transpose(attn, [0, 2, 3, 1])
   image = tf.pow(image, 0.2)  # for high-dynamic-range
   # Each head will correspond to one of RGB.
   # pad the heads to be a multiple of 3
-  image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]])
+  image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, tf.mod(-num_heads, 3)]])
   image = split_last_dimension(image, 3)
   image = tf.reduce_max(image, 4)
   if image_shapes is not None:
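With `num_heads` now read dynamically via `tf.shape(attn)[1]`, the pad amount is written as `tf.mod(-num_heads, 3)`. A minimal sketch of the resulting shapes, assuming a TF1-style session and toy sizes chosen only for illustration (not part of the commit):

    import tensorflow as tf

    attn = tf.random_uniform([2, 5, 7, 7])   # [batch, heads, query_len, memory_len]
    num_heads = tf.shape(attn)[1]            # dynamic head count, unlike get_shape()
    image = tf.transpose(attn, [0, 2, 3, 1])
    # tf.mod(-5, 3) == 1, so one zero head is appended and 6 heads form 2 RGB groups
    image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, tf.mod(-num_heads, 3)]])
    with tf.Session() as sess:
      print(sess.run(tf.shape(image)))       # [2, 7, 7, 6]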
@@ -345,6 +345,95 @@ def dot_product_attention(q,
   return tf.matmul(weights, v)


+def masked_local_attention_1d(
+    q, k, v, block_length=128, summaries=True, name=None):
+  """Attention to the source position and a neighborhood to the left of it.
+
+  The sequence is divided into blocks of length block_length.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.
+
+  Masking to the right is always applied: a query position can never see
+  source positions greater than itself.
+
+  Args:
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
+    block_length: an integer
+    summaries: a boolean
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, length, depth_v]
+  """
+  with tf.variable_scope(name, default_name="local_attention_1d",
+                         values=[q, k, v]):
+    v_shape = v.get_shape()
+    batch = tf.shape(q)[0]
+    heads = tf.shape(q)[1]
+    length = tf.shape(q)[2]
+    # If (length < 2 * block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length * 2),
+                            length, block_length)
+    depth_k = tf.shape(q)[3]
+    depth_v = tf.shape(v)[3]
+    original_length = length
+    padding_size = tf.mod(-length, block_length)
+    length += padding_size
+    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+    q = tf.pad(q, padding)
+    k = tf.pad(k, padding)
+    v = tf.pad(v, padding)
+    num_blocks = tf.div(length, block_length)
+
+    # compute attention for the first query block.
+    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_output = dot_product_attention(
+        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
+        summaries=summaries, name="first_block")
+
+    # compute attention for all subsequent query blocks.
+    q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
+    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+
+    def local(x):
+      """Create a local version of the keys or values."""
+      prev_block = tf.slice(
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+      cur_block = tf.slice(
+          x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+      return tf.concat([prev_block, cur_block], 3)
+    local_k = local(k)
+    local_v = local(v)
+    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+
+    local_length = tf.shape(local_k)[3]
+
+    # [batch, heads, num_blocks - 1, block_length, local_length]
+    attention = tf.matmul(tail_q, local_k, transpose_b=True)
+
+    # make sure source_pos <= target_pos
+    good_part = tf.matrix_band_part(
+        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    mask = (1.0 - good_part) * -1e9
+    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    attention = tf.nn.softmax(attention)
+    # TODO(noam): figure out how to show a summary for the remaining blocks.
+    # The naive way currently causes errors due to empty tensors.
+    # output: [batch, heads, num_blocks-1, block_length, depth_v]
+    output = tf.matmul(attention, local_v)
+    output = tf.reshape(output, [batch, heads, -1, depth_v])
+    output = tf.concat([first_output, output], axis=2)
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    output.set_shape(v_shape)
+    return output
+
+
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
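The right-masking for the non-first blocks comes from the `tf.matrix_band_part` call above: each query row may see the whole previous block plus current-block positions up to itself. A standalone sketch of that mask with a toy block_length of 4, assuming a TF1-style session (values chosen only for illustration):

    import tensorflow as tf

    block_length = 4
    local_length = 2 * block_length  # previous block + current block
    # Keep entry (i, j) iff j - i <= block_length, i.e. query i of the current
    # block sees all of the previous block and current positions 0..i.
    good_part = tf.matrix_band_part(
        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
    mask = (1.0 - good_part) * -1e9   # added to the logits before softmax
    with tf.Session() as sess:
      print(sess.run(good_part))
    # [[1. 1. 1. 1. 1. 0. 0. 0.]
    #  [1. 1. 1. 1. 1. 1. 0. 0.]
    #  [1. 1. 1. 1. 1. 1. 1. 0.]
    #  [1. 1. 1. 1. 1. 1. 1. 1.]]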
@@ -355,6 +444,8 @@ def multihead_attention(query_antecedent,
                         dropout_rate,
                         summaries=False,
                         image_shapes=None,
+                        attention_type="dot_product",
+                        block_length=128,
                         name=None):
   """Multihead scaled-dot-product attention with input/output transformations.

@@ -370,6 +461,8 @@ def multihead_attention(query_antecedent,
     summaries: a boolean
     image_shapes: optional tuple of integer scalars.
       see comments for attention_image_summary()
+    attention_type: a string, either "dot_product" or "local_mask_right"
+    block_length: an integer - relevant for "local_mask_right"
     name: an optional string

   Returns:
@@ -414,8 +507,14 @@ def multihead_attention(query_antecedent,
     v = split_heads(v, num_heads)
     key_depth_per_head = total_key_depth // num_heads
     q *= key_depth_per_head**-0.5
-    x = dot_product_attention(
-        q, k, v, bias, dropout_rate, summaries, image_shapes)
+    if attention_type == "dot_product":
+      x = dot_product_attention(
+          q, k, v, bias, dropout_rate, summaries, image_shapes)
+    else:
+      assert attention_type == "local_mask_right"
+      x = masked_local_attention_1d(q, k, v,
+                                    block_length=block_length,
+                                    summaries=summaries)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
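With the dispatch above, callers opt into the local path by name. A hypothetical call sketch; the tensor sizes and the memory_antecedent=None self-attention convention are assumptions about the surrounding code, not part of this diff:

    x = tf.random_normal([2, 100, 512])      # [batch, length, hidden]
    y = multihead_attention(
        query_antecedent=x,
        memory_antecedent=None,              # assumed self-attention convention
        bias=None,                           # the local path ignores bias
        total_key_depth=512,
        total_value_depth=512,
        output_depth=512,
        num_heads=8,
        dropout_rate=0.0,
        attention_type="local_mask_right",
        block_length=32)
    # y: [2, 100, 512]; each position attends only to itself and to earlier
    # positions within its block and the preceding block.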