Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8f05bab

Browse files
Lukasz KaiserRyan Sepassi
authored andcommitted
Correction for eager-mode decoding scopes (to use with pre-trained checkpoints).
PiperOrigin-RevId: 177261645
1 parent 0ffe0e6 commit 8f05bab

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

tensor2tensor/models/transformer.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,12 @@ def _greedy_infer(self, features, decode_length):
163163
Raises:
164164
NotImplementedError: If there are multiple data shards.
165165
"""
166-
with tf.variable_scope(self.name):
167-
# TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
168-
# with accessing _shape which is used in fast decoding currently.
169-
if self._hparams.use_eager_mode:
170-
return self._slow_greedy_infer(features, decode_length)
171-
else:
166+
# TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
167+
# with accessing _shape which is used in fast decoding currently.
168+
if self._hparams.use_eager_mode:
169+
return self._slow_greedy_infer(features, decode_length)
170+
else:
171+
with tf.variable_scope(self.name):
172172
decoded_ids, _ = self._fast_decode(features, decode_length)
173173
return decoded_ids, None, None
174174

@@ -186,13 +186,13 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
186186
Returns:
187187
samples: an integer `Tensor`. Top samples from the beam search
188188
"""
189-
with tf.variable_scope(self.name):
190-
# TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
191-
# with accessing _shape which is used in fast decoding currently.
192-
if self._hparams.use_eager_mode:
193-
return self._beam_decode_slow(
194-
features, decode_length, beam_size, top_beams, alpha)
195-
else:
189+
# TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
190+
# with accessing _shape which is used in fast decoding currently.
191+
if self._hparams.use_eager_mode:
192+
return self._beam_decode_slow(
193+
features, decode_length, beam_size, top_beams, alpha)
194+
else:
195+
with tf.variable_scope(self.name):
196196
decoded_ids, scores = self._fast_decode(features, decode_length,
197197
beam_size, top_beams, alpha)
198198
return {"outputs": decoded_ids, "scores": scores}

0 commit comments

Comments
 (0)