Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit bd300de

Browse files
committed
Two fix:
1. for beam search use the shape.ndims to avoid out of index error 2. for greedy search, still use the shape.ndims to avoid the out of index error. Before that I misuse the slice operation:-(
1 parent 5563a06 commit bd300de

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensor2tensor/utils/t2t_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,10 @@ def symbols_to_logits_fn(ids):
196196
if last_position_only:
197197
return tf.squeeze(logits, axis=[1, 2, 3])
198198
current_output_position = tf.shape(ids)[1] - 1 # -1 due to the pad above.
199-
logits = logits[:, current_output_position, :, :]
199+
if current_output_position.shape.ndims >= 1:
200+
logits = logits[:, current_output_position, :, :]
201+
else:
202+
logits = logits[:, -1 , :, :]
200203
return tf.squeeze(logits, axis=[1, 2])
201204

202205
batch_size = tf.shape(features["inputs"])[0]
@@ -270,7 +273,7 @@ def infer_step(recent_output, _):
270273
cur_sample = samples[:, -1, :, :]
271274
else:
272275
#Avoid the out of index Error
273-
if len(tf.shape(recent_output)) >= 2:
276+
if tf.shape(recent_output).shape.ndims >= 2:
274277
cur_sample = samples[:, tf.shape(recent_output)[1], :, :]
275278
else:
276279
cur_sample = samples[:, -1, :, :]

0 commit comments

Comments
 (0)