@@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None):
271271
272272 Args:
273273 attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
274- image_shapes: optional quadruple of integer scalars.
274+ image_shapes: optional tuple of integer scalars.
275275 If the query positions and memory positions represent the
276- pixels of a flattened image , then pass in their dimensions:
276+ pixels of flattened images, then pass in their dimensions:
277277 (query_rows, query_cols, memory_rows, memory_cols).
278+ If the query positions and memory positions represent the
279+ pixels x channels of flattened images, then pass in their dimensions:
280+ (query_rows, query_cols, query_channels,
281+ memory_rows, memory_cols, memory_channels).
278282 """
279283 num_heads = attn .get_shape ().as_list ()[1 ]
280284 # [batch, query_length, memory_length, num_heads]
@@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None):
286290 image = split_last_dimension (image , 3 )
287291 image = tf .reduce_max (image , 4 )
288292 if image_shapes is not None :
289- q_rows , q_cols , m_rows , m_cols = list (image_shapes )
290- image = tf .reshape (image , [- 1 , q_rows , q_cols , m_rows , m_cols , 3 ])
291- image = tf .transpose (image , [0 , 1 , 3 , 2 , 4 , 5 ])
292- image = tf .reshape (image , [- 1 , q_rows * m_rows , q_cols * m_cols , 3 ])
293+ if len (image_shapes ) == 4 :
294+ q_rows , q_cols , m_rows , m_cols = list (image_shapes )
295+ image = tf .reshape (image , [- 1 , q_rows , q_cols , m_rows , m_cols , 3 ])
296+ image = tf .transpose (image , [0 , 1 , 3 , 2 , 4 , 5 ])
297+ image = tf .reshape (image , [- 1 , q_rows * m_rows , q_cols * m_cols , 3 ])
298+ else :
299+ assert len (image_shapes ) == 6
300+ q_rows , q_cols , q_channels , m_rows , m_cols , m_channels = list (
301+ image_shapes )
302+ image = tf .reshape (image , [- 1 , q_rows , q_cols , q_channels ,
303+ m_rows , m_cols , m_channels , 3 ])
304+ image = tf .transpose (image , [0 , 1 , 4 , 3 , 2 , 5 , 6 , 7 ])
305+ image = tf .reshape (image , [- 1 , q_rows * m_rows * q_channels ,
306+ q_cols * m_cols * m_channels , 3 ])
293307 tf .summary .image ("attention" , image , max_outputs = 1 )
294308
295309
@@ -310,10 +324,8 @@ def dot_product_attention(q,
310324 bias: bias Tensor (see attention_bias())
311325 dropout_rate: a floating point number
312326 summaries: a boolean
313- image_shapes: optional quadruple of integer scalars for image summary.
314- If the query positions and memory positions represent the
315- pixels of a flattened image, then pass in their dimensions:
316- (query_rows, query_cols, memory_rows, memory_cols).
327+ image_shapes: optional tuple of integer scalars.
328+ see comments for attention_image_summary()
317329 name: an optional string
318330
319331 Returns:
@@ -356,10 +368,8 @@ def multihead_attention(query_antecedent,
356368 num_heads: an integer dividing total_key_depth and total_value_depth
357369 dropout_rate: a floating point number
358370 summaries: a boolean
359- image_shapes: optional quadruple of integer scalars for image summary.
360- If the query positions and memory positions represent the
361- pixels of a flattened image, then pass in their dimensions:
362- (query_rows, query_cols, memory_rows, memory_cols).
371+ image_shapes: optional tuple of integer scalars.
372+ see comments for attention_image_summary()
363373 name: an optional string
364374
365375 Returns:
0 commit comments