This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 72b8f7b

T2T Team authored and copybara-github committed
Simplify implementation of sample_temperature_per_example and make it work with dynamic shapes.
PiperOrigin-RevId: 308944400
1 parent 022387c commit 72b8f7b
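
Background sketch (not part of the commit): "per-example" here means each row of a [batch, ...] logits tensor is sampled with its own temperature and its own top-k cutoff, with -1 meaning "no top-k filtering" for that row. A minimal NumPy illustration of per-example temperature scaling, under those assumptions:

import numpy as np

logits = np.random.randn(3, 7)                # [batch, vocab]
temperature = np.array([0.5, 1.0, 2.0])       # one temperature per example
scaled = logits / temperature.reshape(-1, 1)  # broadcast over the vocab axis
scaled -= scaled.max(axis=-1, keepdims=True)  # stabilize the softmax
probs = np.exp(scaled) / np.exp(scaled).sum(axis=-1, keepdims=True)
samples = [np.random.choice(7, p=p) for p in probs]  # one token id per example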

2 files changed: +51, -39 lines changed


tensor2tensor/layers/common_layers.py

Lines changed: 6 additions & 38 deletions

@@ -2886,25 +2886,6 @@ def sample_with_temperature(logits, temperature, sampling_keep_top_k=-1):
   return choices
 
 
-def _to_nd_indices(indices):
-  """Returns indices used for tf.gather_nd or tf.scatter_nd.
-
-  Args:
-    indices: A `Tensor` of shape [batch_size, size] with integer values. The
-      values are the indices of another `Tensor`. For example, `indices` is the
-      output of tf.argsort or tf.math.top_k.
-
-  Returns:
-    A `Tensor` with shape [batch_size, size, 2] that can be used by tf.gather_nd
-      or tf.scatter_nd.
-
-  """
-  indices.get_shape().assert_has_rank(2)
-  batch_ids = tf.ones_like(indices) * tf.expand_dims(
-      tf.range(tf.shape(input=indices)[0]), 1)
-  return tf.stack([batch_ids, indices], axis=-1)
-
-
 def _select_top_k(logits, top_k):
   """Replaces logits, expect the top k highest values, with small number (-1e6).
 
@@ -2918,26 +2899,15 @@ def _select_top_k(logits, top_k):
     A `Tensor` with same shape as logits.
   """
   vocab_size = logits.shape[-1]
-  flat_logits = tf.reshape(logits, [-1, vocab_size])
+
   top_k = tf.where(
       tf.not_equal(top_k, -1), top_k,
       tf.ones_like(top_k) * vocab_size)
-  values, idx = tf.math.top_k(flat_logits, k=vocab_size, sorted=False)
-  nd_idx = _to_nd_indices(idx)
 
-  mask_idx = tf.reshape(
-      tf.range(vocab_size), [1] * (len(logits.shape) - 1) + [-1])
-  for i, size in enumerate(logits.shape[:-1]):
-    mask_idx = tf.repeat(mask_idx, size, axis=i)
-  mask = tf.reshape(
-      mask_idx < tf.reshape(top_k, [-1] + [1] * (len(logits.shape) - 1)), [-1])
-
-  topk_logits = tf.tensor_scatter_nd_update(
-      tf.ones_like(flat_logits) * -1e6,
-      tf.reshape(nd_idx, [-1, 2])[mask],
-      tf.reshape(values, [-1])[mask])
-
-  return tf.reshape(topk_logits, logits.shape)
+  return tf.where(
+      tf.argsort(logits) < tf.reshape(top_k, [-1] + [1] *
+                                      (len(logits.shape) - 1)), logits,
+      tf.ones_like(logits) * -1e6)
 
 
 def sample_temperature_per_example(logits, temperature, sampling_keep_top_k=-1):
@@ -2950,9 +2920,7 @@ def sample_temperature_per_example(logits, temperature, sampling_keep_top_k=-1):
   Returns:
     a Tensor with one fewer dimension than logits.
   """
-  if sampling_keep_top_k != -1:
-    logits = _select_top_k(logits, sampling_keep_top_k)
-
+  logits = _select_top_k(logits, sampling_keep_top_k)
   logits /= tf.reshape(temperature, [-1] + [1] * (len(logits.shape) - 1))
   reshaped_logits = tf.reshape(logits, [-1, shape_list(logits)[-1]])
   choices = tf.multinomial(reshaped_logits, 1)
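
For readers skimming the diff: the rewritten _select_top_k is meant to keep, for each example, only that example's top_k highest logits and push everything else down to -1e6, with top_k == -1 meaning "keep the whole vocabulary", without relying on static batch dimensions. A minimal NumPy reference for that intended behavior (illustrative only; select_top_k_reference is a hypothetical helper, not code from this commit):

import numpy as np

def select_top_k_reference(logits, top_k, mask_value=-1e6):
  """logits: [batch, vocab]; top_k: [batch] ints, where -1 keeps every logit."""
  batch, vocab = logits.shape
  out = np.full_like(logits, mask_value)
  for i in range(batch):
    k = vocab if top_k[i] == -1 else int(top_k[i])
    keep = np.argsort(logits[i])[-k:]   # indices of the k largest logits
    out[i, keep] = logits[i, keep]
  return out

logits = np.array([[0.1, 2.0, -1.0, 0.5],
                   [3.0, 0.0, 1.0, 2.0]])
print(select_top_k_reference(logits, np.array([2, -1])))
# row 0 keeps only its 2 largest logits; row 1 (top_k = -1) is left untouched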

tensor2tensor/layers/common_layers_test.py

Lines changed: 45 additions & 1 deletion

@@ -704,11 +704,55 @@ def testSampleTemperaturePerExample(self):
     logits = np.random.randn(batch_size, seq_len, 1, 1, vocab_size)
     temperature = np.random.rand(batch_size)
 
-    out = common_layers.sample_temperature_per_example(logits, temperature)
+    out = common_layers.sample_temperature_per_example(logits, temperature, -1)
 
     self.assertAllEqual(
         self.evaluate(tf.shape(out)), [batch_size, seq_len, 1, 1])
 
+  @test_utils.run_in_graph_and_eager_modes()
+  def testSampleTemperaturePerExampleWithTopK(self):
+    batch_size = 3
+    seq_len = 5
+    vocab_size = 7
+
+    logits = np.random.randn(batch_size, seq_len, 1, 1, vocab_size)
+    temperature = np.random.rand(batch_size)
+    top_k = np.array([3, -1, 4], dtype=np.int32)
+
+    out = common_layers.sample_temperature_per_example(logits, temperature,
+                                                       top_k)
+
+    self.assertAllEqual(
+        self.evaluate(tf.shape(out)), [batch_size, seq_len, 1, 1])
+
+  @test_utils.run_in_graph_and_eager_modes()
+  def testSampleTemperaturePerExampleWithTopK2(self):
+    batch_size = 3
+    vocab_size = 7
+
+    logits = np.random.randn(batch_size, vocab_size)
+    temperature = np.random.rand(batch_size)
+    top_k = np.array([3, -1, 4], dtype=np.int32)
+
+    out = common_layers.sample_temperature_per_example(logits, temperature,
+                                                       top_k)
+
+    self.assertAllEqual(self.evaluate(tf.shape(out)), [batch_size])
+
+  @test_utils.run_in_graph_mode_only()
+  def testSampleTemperaturePerExampleDynamicBatchSize(self):
+    batch_size = None
+    vocab_size = 7
+
+    logits = tf.placeholder(tf.float32, shape=(batch_size, vocab_size))
+    temperature = tf.placeholder(tf.float32, shape=(batch_size, 1))
+    sampling_keep_top_k = tf.placeholder(tf.int32, shape=(batch_size, 1))
+
+    out = common_layers.sample_temperature_per_example(logits, temperature,
+                                                       sampling_keep_top_k)
+
+    self.assertAllEqual(out.shape.as_list(), [batch_size])
+
   @test_utils.run_in_graph_and_eager_modes()
   def testCycleGANUpsampleNnUpsampleConv(self):
     batch = 8
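
A possible way to exercise the new dynamic-batch-size path end to end, assuming a TF1-style session as in the tests above; the feed values are illustrative and mirror the shapes in testSampleTemperaturePerExampleDynamicBatchSize, not part of the commit:

import numpy as np
import tensorflow as tf
from tensor2tensor.layers import common_layers

logits = tf.placeholder(tf.float32, shape=(None, 7))
temperature = tf.placeholder(tf.float32, shape=(None, 1))
sampling_keep_top_k = tf.placeholder(tf.int32, shape=(None, 1))

samples = common_layers.sample_temperature_per_example(
    logits, temperature, sampling_keep_top_k)

with tf.Session() as sess:
  out = sess.run(samples, feed_dict={
      logits: np.random.randn(3, 7),
      temperature: np.ones((3, 1), dtype=np.float32),
      sampling_keep_top_k: np.array([[3], [-1], [4]], dtype=np.int32),
  })
  print(out.shape)  # expected (3,): one sampled token id per example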
