
Commit 8fa33f6

katelee168 authored and Ryan Sepassi committed

Add Gaussian label smoothing.

PiperOrigin-RevId: 174383193

1 parent 5aedc3d · commit 8fa33f6

File tree: 1 file changed (+33 −8 lines)


tensor2tensor/layers/common_layers.py (33 additions, 8 deletions)
@@ -1477,21 +1477,46 @@ def padded_cross_entropy(logits,
     return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
 
 
-def smoothing_cross_entropy(logits, labels, vocab_size, confidence):
-  """Cross entropy with label smoothing to limit over-confidence."""
+def smoothing_cross_entropy(logits, labels, vocab_size, confidence,
+                            gaussian=False):
+  """Cross entropy with label smoothing to limit over-confidence.
+
+  Args:
+    logits: Tensor of size [batch_size, ?, ?, ?, vocab_size].
+    labels: Tensor of size [batch_size, ?, ?, ?].
+    vocab_size: Tensor representing the size of the vocabulary.
+    confidence: Used to determine on and off values for label smoothing.
+      If `gaussian` is true, `confidence` is the standard deviation of the
+      Gaussian distribution.
+    gaussian: Whether to use a Gaussian distribution for label smoothing.
+
+  Returns:
+    Tensor of size [batch_size, ?, ?, ?] with the smoothed cross entropy.
+  """
   with tf.name_scope("smoothing_cross_entropy", [logits, labels]):
     # Low confidence is given to all non-true labels, uniformly.
     low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1)
     # Normalizing constant is the best cross-entropy value with soft targets.
     # We subtract it just for readability, makes no difference on learning.
     normalizing = -(confidence * tf.log(confidence) + tf.to_float(
         vocab_size - 1) * low_confidence * tf.log(low_confidence + 1e-20))
-    # Soft targets.
-    soft_targets = tf.one_hot(
-        tf.cast(labels, tf.int32),
-        depth=vocab_size,
-        on_value=confidence,
-        off_value=low_confidence)
+
+    if gaussian:
+      labels = tf.cast(labels, tf.float32)
+
+      normal_dist = tf.distributions.Normal(loc=labels, scale=confidence)
+      # Locations at which to evaluate the probability distributions.
+      soft_targets = normal_dist.prob(tf.cast(tf.range(vocab_size), tf.float32)
+                                      [:, None, None, None, None])
+      # Reorder soft_targets from [vocab_size, batch_size, ?, ?, ?] to match
+      # logits: [batch_size, ?, ?, ?, vocab_size].
+      soft_targets = tf.transpose(soft_targets, perm=[1, 2, 3, 4, 0])
+    else:
+      soft_targets = tf.one_hot(
+          tf.cast(labels, tf.int32),
+          depth=vocab_size,
+          on_value=confidence,
+          off_value=low_confidence)
     xentropy = tf.nn.softmax_cross_entropy_with_logits(
         logits=logits, labels=soft_targets)
     return xentropy - normalizing
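
For intuition, here is a minimal NumPy sketch (not part of the commit) of what
the gaussian=True branch computes, reduced to a 1-D batch of label ids; the
toy vocab_size, confidence, and labels values below are made up for
illustration.

import numpy as np

vocab_size = 10
confidence = 1.0           # used as the Gaussian scale (standard deviation)
labels = np.array([2, 7])  # toy batch of label ids, shape [batch_size]

# Evaluate a Gaussian density centered at each label at every vocabulary id,
# mirroring normal_dist.prob(...): result shape [vocab_size, batch_size].
positions = np.arange(vocab_size, dtype=np.float32)[:, None]
soft_targets = (np.exp(-0.5 * ((positions - labels) / confidence) ** 2)
                / (confidence * np.sqrt(2.0 * np.pi)))

# Put the vocab axis last to match the logits layout, just as the diff
# transposes [vocab_size, batch_size, ?, ?, ?] into
# [batch_size, ?, ?, ?, vocab_size].
soft_targets = soft_targets.T  # shape [batch_size, vocab_size]
print(soft_targets.round(3))

Each row peaks at its label id and decays smoothly over neighboring ids, so
near-misses receive more target mass than distant ids, unlike uniform label
smoothing. Note that the densities evaluated at integer positions are not
renormalized to sum to one; the diff uses them directly as soft targets, and
the normalizing constant it subtracts is, per the code comment, only a
readability offset.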
