@@ -1477,21 +1477,46 @@ def padded_cross_entropy(logits,
     return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
 
 
-def smoothing_cross_entropy(logits, labels, vocab_size, confidence):
-  """Cross entropy with label smoothing to limit over-confidence."""
+def smoothing_cross_entropy(logits, labels, vocab_size, confidence,
+                            gaussian=False):
1482+ """Cross entropy with label smoothing to limit over-confidence.
1483+
1484+ Args:
1485+ logits: Tensor of size [batch_size, ?, ?, ?, vocab_size]
1486+ labels: Tensor of size [batch_size, ?, ?, ?]
1487+ vocab_size: Tensor representing the size of the vocabulary.
1488+ confidence: Used to determine on and off values for label smoothing.
1489+ If `gaussian` is true, `confidence` is the variance to the gaussian
1490+ distribution.
1491+ gaussian: Uses a gaussian distribution for label smoothing
1492+
1493+ Returns:
1494+
1495+ """
   with tf.name_scope("smoothing_cross_entropy", [logits, labels]):
     # Low confidence is given to all non-true labels, uniformly.
     low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1)
     # Normalizing constant is the best cross-entropy value with soft targets.
     # We subtract it just for readability, makes no difference on learning.
     normalizing = -(confidence * tf.log(confidence) + tf.to_float(
         vocab_size - 1) * low_confidence * tf.log(low_confidence + 1e-20))
-    # Soft targets.
-    soft_targets = tf.one_hot(
-        tf.cast(labels, tf.int32),
-        depth=vocab_size,
-        on_value=confidence,
-        off_value=low_confidence)
+
+    if gaussian:
+      labels = tf.cast(labels, tf.float32)
+
+      normal_dist = tf.distributions.Normal(loc=labels, scale=confidence)
+      # Locations to evaluate the probability distributions.
+      soft_targets = normal_dist.prob(
+          tf.cast(tf.range(vocab_size), tf.float32)[:, None, None, None, None])
+      # Reordering soft_targets from [vocab_size, batch_size, ?, ?, ?] to match
+      # logits: [batch_size, ?, ?, ?, vocab_size]
+      soft_targets = tf.transpose(soft_targets, perm=[1, 2, 3, 4, 0])
+    else:
+      soft_targets = tf.one_hot(
+          tf.cast(labels, tf.int32),
+          depth=vocab_size,
+          on_value=confidence,
+          off_value=low_confidence)
     xentropy = tf.nn.softmax_cross_entropy_with_logits(
         logits=logits, labels=soft_targets)
     return xentropy - normalizing
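
A minimal NumPy/SciPy sketch of what the two branches compute for a single position (not part of the diff; scipy.stats.norm stands in for tf.distributions.Normal, and the vocab size, label id, and confidence values are illustrative):

import numpy as np
from scipy.stats import norm

vocab_size = 8
label = 3          # true token id at one position
confidence = 0.9   # on-value for one-hot smoothing; also passed as the Gaussian scale

# One-hot branch: keep `confidence` on the true id, spread the rest uniformly.
low_confidence = (1.0 - confidence) / (vocab_size - 1)
one_hot_targets = np.full(vocab_size, low_confidence)
one_hot_targets[label] = confidence

# Normalizing constant: the cross-entropy of the soft targets with themselves,
# i.e. the lowest achievable loss; subtracting it only shifts the reported value.
normalizing = -(confidence * np.log(confidence) +
                (vocab_size - 1) * low_confidence * np.log(low_confidence + 1e-20))

# Gaussian branch: evaluate a Normal centred on the true id at every vocab id.
ids = np.arange(vocab_size, dtype=np.float32)
gaussian_targets = norm.pdf(ids, loc=label, scale=confidence)

print(one_hot_targets.round(3), round(normalizing, 3))
print(gaussian_targets.round(3))  # densities, not probabilities; they need not sum to 1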