Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 06df1d4

Browse files
authored
Merge pull request #69 from yynil/master
I'm afraid I need to ask for pull request again for the "Out of index" Error when doing Class_Label_Modality
2 parents af235c1 + bd300de commit 06df1d4

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensor2tensor/utils/t2t_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,10 @@ def symbols_to_logits_fn(ids):
196196
if last_position_only:
197197
return tf.squeeze(logits, axis=[1, 2, 3])
198198
current_output_position = tf.shape(ids)[1] - 1 # -1 due to the pad above.
199-
logits = logits[:, current_output_position, :, :]
199+
if current_output_position.shape.ndims >= 1:
200+
logits = logits[:, current_output_position, :, :]
201+
else:
202+
logits = logits[:, -1 , :, :]
200203
return tf.squeeze(logits, axis=[1, 2])
201204

202205
batch_size = tf.shape(features["inputs"])[0]
@@ -270,7 +273,7 @@ def infer_step(recent_output, _):
270273
cur_sample = samples[:, -1, :, :]
271274
else:
272275
#Avoid the out of index Error
273-
if len(tf.shape(recent_output)) >= 2:
276+
if tf.shape(recent_output).shape.ndims >= 2:
274277
cur_sample = samples[:, tf.shape(recent_output)[1], :, :]
275278
else:
276279
cur_sample = samples[:, -1, :, :]

0 commit comments

Comments
 (0)