@@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None):
 
   Args:
     attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
-    image_shapes: optional quadruple of integer scalars.
+    image_shapes: optional tuple of integer scalars.
       If the query positions and memory positions represent the
-        pixels of a flattened image, then pass in their dimensions:
+        pixels of flattened images, then pass in their dimensions:
           (query_rows, query_cols, memory_rows, memory_cols).
+      If the query positions and memory positions represent the
+        pixels x channels of flattened images, then pass in their dimensions:
+          (query_rows, query_cols, query_channels,
+           memory_rows, memory_cols, memory_channels).
   """
   num_heads = attn.get_shape().as_list()[1]
   # [batch, query_length, memory_length, num_heads]
@@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None):
   image = split_last_dimension(image, 3)
   image = tf.reduce_max(image, 4)
   if image_shapes is not None:
-    q_rows, q_cols, m_rows, m_cols = list(image_shapes)
-    image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
-    image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
-    image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
+    if len(image_shapes) == 4:
+      q_rows, q_cols, m_rows, m_cols = list(image_shapes)
+      image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
+      image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
+      image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
+    else:
+      assert len(image_shapes) == 6
+      q_rows, q_cols, q_channels, m_rows, m_cols, m_channels = list(
+          image_shapes)
+      image = tf.reshape(image, [-1, q_rows, q_cols, q_channels,
+                                 m_rows, m_cols, m_channels, 3])
+      image = tf.transpose(image, [0, 1, 4, 3, 2, 5, 6, 7])
+      image = tf.reshape(image, [-1, q_rows * m_rows * q_channels,
+                                 q_cols * m_cols * m_channels, 3])
   tf.summary.image("attention", image, max_outputs=1)
 
 
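As a sanity check on the new 6-tuple branch, the reshape/transpose/reshape sequence can be mirrored in NumPy. The sketch below is not part of the commit and uses arbitrary toy shapes; it only confirms that each query position ends up tiled against the full memory grid in a single displayable summary image:

import numpy as np

batch, q_rows, q_cols, q_ch = 2, 4, 4, 3
m_rows, m_cols, m_ch = 8, 8, 3

# Stand-in for the [batch, query_length, memory_length, 3] tensor that the
# summary code builds, where each length is the flattened rows * cols * channels.
image = np.random.rand(batch, q_rows * q_cols * q_ch,
                       m_rows * m_cols * m_ch, 3)

# Same axis manipulations as the else-branch above, in NumPy.
x = image.reshape(batch, q_rows, q_cols, q_ch, m_rows, m_cols, m_ch, 3)
x = x.transpose(0, 1, 4, 3, 2, 5, 6, 7)
x = x.reshape(batch, q_rows * m_rows * q_ch, q_cols * m_cols * m_ch, 3)
print(x.shape)  # (2, 96, 96, 3): one composite image per batch element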
@@ -310,10 +324,8 @@ def dot_product_attention(q,
     bias: bias Tensor (see attention_bias())
     dropout_rate: a floating point number
     summaries: a boolean
-    image_shapes: optional quadruple of integer scalars for image summary.
-      If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
-        (query_rows, query_cols, memory_rows, memory_cols).
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
     name: an optional string
 
   Returns:
@@ -356,10 +368,8 @@ def multihead_attention(query_antecedent,
     num_heads: an integer dividing total_key_depth and total_value_depth
     dropout_rate: a floating point number
     summaries: a boolean
-    image_shapes: optional quadruple of integer scalars for image summary.
-      If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
-        (query_rows, query_cols, memory_rows, memory_cols).
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
     name: an optional string
 
   Returns:
@@ -398,3 +408,72 @@ def multihead_attention(query_antecedent,
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
+
+
+def parameter_attention(x,
+                        total_key_depth,
+                        total_value_depth,
+                        output_depth,
+                        memory_rows,
+                        num_heads,
+                        dropout_rate,
+                        name=None):
+  """Attention over parameters.
+
+  We use the same multi-headed attention as in the other layers, but the
+  memory keys and values are model parameters. There is no linear
+  transformation on the keys or values.
+
+  We are also a bit more careful about memory usage, since the number of
+  memory positions may be very large.
+
+  Args:
+    x: a Tensor with shape [batch, length_q, channels]
+    total_key_depth: an integer
+    total_value_depth: an integer
+    output_depth: an integer
+    memory_rows: an integer
+    num_heads: an integer dividing total_key_depth and total_value_depth
+    dropout_rate: a floating point number
+    name: an optional string
+
+  Returns:
+    A Tensor.
+  """
+  with tf.variable_scope(name, default_name="parameter_attention",
+                         values=[x]):
+    head_size_k = total_key_depth // num_heads
+    head_size_v = total_value_depth // num_heads
+    var_shape_k = [num_heads, memory_rows, head_size_k]
+    var_shape_v = [num_heads, memory_rows, head_size_v]
+    k = tf.get_variable(
+        "k", var_shape_k,
+        initializer=tf.random_normal_initializer(
+            0, output_depth ** -0.5)) * (num_heads ** 0.5)
+    v = tf.get_variable(
+        "v", var_shape_v,
+        initializer=tf.random_normal_initializer(
+            0, output_depth ** -0.5)) * (output_depth ** 0.5)
+    batch_size = tf.shape(x)[0]
+    length = tf.shape(x)[1]
+    q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform")
+    if dropout_rate:
+      # This is a cheaper form of attention dropout where we use the same
+      # dropout decisions across batch elements and query positions,
+      # but different decisions across heads and memory positions.
+      v = tf.nn.dropout(v, 1.0 - dropout_rate,
+                        noise_shape=[num_heads, memory_rows, 1])
+    # query is [batch, length, hidden_size]
+    # reshape and transpose it to [heads, batch * length, head_size]
+    q = tf.reshape(q, [batch_size, length, num_heads, head_size_k])
+    q = tf.transpose(q, [2, 0, 1, 3])
+    q = tf.reshape(q, [num_heads, batch_size * length, head_size_k])
+    weights = tf.matmul(q, k, transpose_b=True)
+    weights = tf.nn.softmax(weights)
+    y = tf.matmul(weights, v)
+    y = tf.reshape(y, [num_heads, batch_size, length, head_size_v])
+    y = tf.transpose(y, [1, 2, 0, 3])
+    y = tf.reshape(y, [batch_size, length, total_value_depth])
+    y.set_shape([None, None, total_value_depth])
+    y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
+    return y
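For context, here is a rough usage sketch of the new layer. It is not part of the commit, the shapes and hyperparameters are hypothetical, and it assumes the module's usual imports (tensorflow as tf, common_layers) plus a TF1-style graph and session:

x = tf.random_normal([8, 100, 512])      # [batch, length_q, channels]
y = parameter_attention(x,
                        total_key_depth=512,
                        total_value_depth=512,
                        output_depth=512,
                        memory_rows=2048,  # learned key/value slots per head
                        num_heads=8,
                        dropout_rate=0.1,
                        name="param_attn")
# y: [8, 100, 512]. The 2048 memory rows per head are trainable variables
# shared across every query position and batch element.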