@@ -2886,25 +2886,6 @@ def sample_with_temperature(logits, temperature, sampling_keep_top_k=-1):
28862886 return choices
28872887
28882888
2889- def _to_nd_indices (indices ):
2890- """Returns indices used for tf.gather_nd or tf.scatter_nd.
2891-
2892- Args:
2893- indices: A `Tensor` of shape [batch_size, size] with integer values. The
2894- values are the indices of another `Tensor`. For example, `indices` is the
2895- output of tf.argsort or tf.math.top_k.
2896-
2897- Returns:
2898- A `Tensor` with shape [batch_size, size, 2] that can be used by tf.gather_nd
2899- or tf.scatter_nd.
2900-
2901- """
2902- indices .get_shape ().assert_has_rank (2 )
2903- batch_ids = tf .ones_like (indices ) * tf .expand_dims (
2904- tf .range (tf .shape (input = indices )[0 ]), 1 )
2905- return tf .stack ([batch_ids , indices ], axis = - 1 )
2906-
2907-
29082889def _select_top_k (logits , top_k ):
29092890 """Replaces logits, expect the top k highest values, with small number (-1e6).
29102891
@@ -2918,26 +2899,15 @@ def _select_top_k(logits, top_k):
29182899 A `Tensor` with same shape as logits.
29192900 """
29202901 vocab_size = logits .shape [- 1 ]
2921- flat_logits = tf . reshape ( logits , [ - 1 , vocab_size ])
2902+
29222903 top_k = tf .where (
29232904 tf .not_equal (top_k , - 1 ), top_k ,
29242905 tf .ones_like (top_k ) * vocab_size )
2925- values , idx = tf .math .top_k (flat_logits , k = vocab_size , sorted = False )
2926- nd_idx = _to_nd_indices (idx )
29272906
2928- mask_idx = tf .reshape (
2929- tf .range (vocab_size ), [1 ] * (len (logits .shape ) - 1 ) + [- 1 ])
2930- for i , size in enumerate (logits .shape [:- 1 ]):
2931- mask_idx = tf .repeat (mask_idx , size , axis = i )
2932- mask = tf .reshape (
2933- mask_idx < tf .reshape (top_k , [- 1 ] + [1 ] * (len (logits .shape ) - 1 )), [- 1 ])
2934-
2935- topk_logits = tf .tensor_scatter_nd_update (
2936- tf .ones_like (flat_logits ) * - 1e6 ,
2937- tf .reshape (nd_idx , [- 1 , 2 ])[mask ],
2938- tf .reshape (values , [- 1 ])[mask ])
2939-
2940- return tf .reshape (topk_logits , logits .shape )
2907+ return tf .where (
2908+ tf .argsort (logits ) < tf .reshape (top_k , [- 1 ] + [1 ] *
2909+ (len (logits .shape ) - 1 )), logits ,
2910+ tf .ones_like (logits ) * - 1e6 )
29412911
29422912
29432913def sample_temperature_per_example (logits , temperature , sampling_keep_top_k = - 1 ):
@@ -2950,9 +2920,7 @@ def sample_temperature_per_example(logits, temperature, sampling_keep_top_k=-1):
29502920 Returns:
29512921 a Tensor with one fewer dimension than logits.
29522922 """
2953- if sampling_keep_top_k != - 1 :
2954- logits = _select_top_k (logits , sampling_keep_top_k )
2955-
2923+ logits = _select_top_k (logits , sampling_keep_top_k )
29562924 logits /= tf .reshape (temperature , [- 1 ] + [1 ] * (len (logits .shape ) - 1 ))
29572925 reshaped_logits = tf .reshape (logits , [- 1 , shape_list (logits )[- 1 ]])
29582926 choices = tf .multinomial (reshaped_logits , 1 )
0 commit comments