This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0bdfcbb

T2T Team authored and Ryan Sepassi committed

Use get_residual_fn to get the residual_fn in the transformer.

PiperOrigin-RevId: 163919630

1 parent c35c7a3 · commit 0bdfcbb

File tree

4 files changed: +13 additions, -9 deletions

README.md

Lines changed: 1 addition & 1 deletion

@@ -180,7 +180,7 @@ python -c "from tensor2tensor.models.transformer import Transformer"
 **Datasets** are all standardized on `TFRecord` files with `tensorflow.Example`
 protocol buffers. All datasets are registered and generated with the
 [data
-generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen)
+generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/generator.py)
 and many common sequence datasets are already available for generation and use.

 ### Problems and Modalities

tensor2tensor/bin/t2t-datagen renamed to tensor2tensor/data_generators/generator.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # coding=utf-8
 # Copyright 2017 The Tensor2Tensor Authors.
 #

tensor2tensor/models/transformer.py

Lines changed: 12 additions & 6 deletions

@@ -56,12 +56,7 @@ def model_fn_body(self, features):
     (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder(
         targets, hparams)

-    def residual_fn(x, y):
-      return common_layers.residual_fn(x, y,
-                                       hparams.norm_type,
-                                       hparams.residual_dropout,
-                                       hparams.hidden_size,
-                                       epsilon=hparams.layer_norm_epsilon)
+    residual_fn = get_residual_fn(hparams)

     encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
     decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout)
@@ -76,6 +71,17 @@ def residual_fn(x, y):
     return decoder_output


+def get_residual_fn(hparams):
+  """Get residual_fn."""
+  def residual_fn(x, y):
+    return common_layers.residual_fn(x, y,
+                                     hparams.norm_type,
+                                     hparams.residual_dropout,
+                                     hparams.hidden_size,
+                                     epsilon=hparams.layer_norm_epsilon)
+  return residual_fn
+
+
 def transformer_prepare_encoder(inputs, target_space, hparams):
   """Prepare one shard of the model for the encoder.
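The diff above replaces an inline closure with a factory function: `get_residual_fn(hparams)` builds `residual_fn` once, capturing `hparams`, so callers no longer re-define the closure inside `model_fn_body`. A minimal sketch of that closure-factory pattern, using a stand-in `make_residual_fn` and a plain namespace instead of the real tensor2tensor `common_layers` API (both are illustrative assumptions, not the library's code):

```python
from types import SimpleNamespace


def make_residual_fn(hparams):
    """Return a residual function with hparams baked in (closure factory)."""
    def residual_fn(x, y):
        # Illustrative residual connection: add y to x, scaled by the
        # keep-probability implied by residual_dropout. The real code
        # delegates to common_layers.residual_fn with normalization.
        keep_prob = 1.0 - hparams.residual_dropout
        return [a + b * keep_prob for a, b in zip(x, y)]
    return residual_fn


# Build the function once, then reuse it anywhere in the model body.
hparams = SimpleNamespace(residual_dropout=0.1)
residual_fn = make_residual_fn(hparams)
print(residual_fn([1.0, 2.0], [10.0, 20.0]))  # → [10.0, 20.0]
```

The refactor keeps behavior identical while making the residual function reusable and testable outside the model class, which is the point of the commit.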

tensor2tensor/bin/t2t-trainer renamed to tensor2tensor/trainer.py

Lines changed: 0 additions & 1 deletion

@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # coding=utf-8
 # Copyright 2017 The Tensor2Tensor Authors.
 #
