
Commit c1dbc36

T2T Team authored and Ryan Sepassi committed
Add expected_attention_loss_multiplier hparam to allow scaling the attention_loss.
PiperOrigin-RevId: 187549764
1 parent 975d7fc · commit c1dbc36

File tree: 3 files changed (+18, -5 lines)


tensor2tensor/layers/common_attention.py

Lines changed: 6 additions & 3 deletions
@@ -327,14 +327,17 @@ def add_standard_attention_hparams(hparams):
   return hparams


-def encoder_decoder_attention_loss(expected_attention, actual_attentions):
+def encoder_decoder_attention_loss(expected_attention,
+                                   actual_attentions,
+                                   loss_multiplier=1.0):
   """Computes encdec attention loss between expected and actual attentions.

   Args:
     expected_attention: Tensor storing the expected encoder-decoder attention
       weights with shape [batch_size, target_length, input_length].
     actual_attentions: Dictionary with actual attention weights for different
       attention types and hidden layers.
+    loss_multiplier: multiplier for the attention loss.

   Returns:
     MSE loss between the actual and expected attention weights.
@@ -351,8 +354,8 @@ def encoder_decoder_attention_loss(expected_attention, actual_attentions):
   # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a
   # tensor with shape [batch_size, target_length, input_length].
   actual_attention_weights = tf.reduce_mean(actual_attention_weights, [0, 2])
-  return tf.losses.mean_squared_error(expected_attention,
-                                      actual_attention_weights)
+  return tf.losses.mean_squared_error(
+      expected_attention, actual_attention_weights) * loss_multiplier


 @expert_utils.add_name_scope()
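
For reference, a minimal NumPy sketch (not part of the commit) of the quantity being scaled: the per-layer, per-head attention weights are averaged down to [batch_size, target_length, input_length] and compared against the expected weights with an MSE, which the new loss_multiplier then rescales. The stacked shape [num_layers, batch_size, num_heads, target_length, input_length] and the name scaled_attention_loss are assumptions inferred from the tf.reduce_mean axes above.

import numpy as np


def scaled_attention_loss(expected, actual_stacked, loss_multiplier=1.0):
  # expected: [batch, target_len, input_len] (assumed shapes, not the
  # tensor2tensor implementation).
  # actual_stacked: [num_layers, batch, num_heads, target_len, input_len].
  # Mirror tf.reduce_mean(actual_attention_weights, [0, 2]): average over
  # layers (axis 0) and heads (axis 2).
  actual = actual_stacked.mean(axis=(0, 2))
  mse = np.mean((expected - actual) ** 2)
  # The new hparam simply rescales the MSE before it joins the other losses.
  return loss_multiplier * mse


expected = np.random.random_sample((2, 6, 9))
actual = np.random.random_sample((4, 2, 8, 6, 9))
assert np.isclose(scaled_attention_loss(expected, actual, 2.0),
                  2.0 * scaled_attention_loss(expected, actual))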

tensor2tensor/models/transformer.py

Lines changed: 10 additions & 1 deletion
@@ -176,7 +176,8 @@ def body(self, features):
     expected_attentions = features.get("expected_attentions")
     if expected_attentions is not None:
       attention_loss = common_attention.encoder_decoder_attention_loss(
-          expected_attentions, self.attention_weights)
+          expected_attentions, self.attention_weights,
+          hparams.expected_attention_loss_multiplier)
       return decoder_output, {"attention_loss": attention_loss}

     return decoder_output
@@ -1462,3 +1463,11 @@ def transformer_librispeech_tpu():
   librispeech.set_librispeech_length_hparams(hparams)
   return hparams

+
+@registry.register_hparams
+def transformer_supervised_attention():
+  """Hparams for supervised attention problems."""
+  hparams = transformer_base()
+  # Multiplier to the encoder-decoder expected attention loss.
+  hparams.add_hparam("expected_attention_loss_multiplier", 1.0)
+  return hparams
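
A hedged usage sketch (not from the commit): the new hparams set can be instantiated directly and the multiplier overridden before training. With the usual t2t-trainer flags this would typically correspond to --hparams_set=transformer_supervised_attention, optionally with --hparams='expected_attention_loss_multiplier=0.5'; the 0.5 value is purely illustrative.

from tensor2tensor.models import transformer

# Select the hparams set registered above and rescale the attention term.
hparams = transformer.transformer_supervised_attention()
hparams.expected_attention_loss_multiplier = 0.5  # illustrative value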

tensor2tensor/models/transformer_test.py

Lines changed: 2 additions & 1 deletion
@@ -208,7 +208,8 @@ def testTransformerWithoutProblem(self):
         [BATCH_SIZE, TARGET_LENGTH, 1, hparams.hidden_size])

   def testTransformerWithEncoderDecoderAttentionLoss(self):
-    model, features = self.getModel(transformer.transformer_small())
+    model, features = self.getModel(
+        transformer.transformer_supervised_attention())
     expected_attention_weights = np.random.random_sample(
         size=(BATCH_SIZE, TARGET_LENGTH, INPUT_LENGTH))
     features["expected_attentions"] = tf.constant(
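
The test drives the loss with unnormalized random weights. A hypothetical helper (not in the commit, names assumed) for building a row-stochastic expected_attentions target of the same shape, should a more attention-like supervision signal be wanted:

import numpy as np


def make_expected_attentions(batch_size, target_length, input_length, seed=0):
  # Each target position's weights over the inputs sum to 1, like a softmax
  # attention distribution; shape [batch_size, target_length, input_length].
  rng = np.random.RandomState(seed)
  weights = rng.random_sample((batch_size, target_length, input_length))
  return weights / weights.sum(axis=-1, keepdims=True)


target = make_expected_attentions(2, 6, 9)
assert np.allclose(target.sum(axis=-1), 1.0)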
