This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f65b5e4

rjpower authored and copybara-github committed
Add basic support for TF2 modeling.
This is not complete, but can be extended to add support for TPUs and more
models as required. Tested on CPU/GPU with the following configuration:

    PROBLEM=translate_envi_iwslt32k
    MODEL=transformer
    HPARAMS=transformer_base_single_gpu
    DATA_DIR=$HOME/t2t_data
    TMP_DIR=/tmp/t2t_datagen
    TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS

    t2t-trainer \
      --data_dir=$DATA_DIR \
      --problem=$PROBLEM \
      --model=$MODEL \
      --hparams_set=$HPARAMS \
      --output_dir=$TRAIN_DIR

Verified the loss decreases as expected and checkpoints etc. work.

PiperOrigin-RevId: 312557333
1 parent c762954 commit f65b5e4

File tree

8 files changed: +144 additions, -48 deletions


setup.py

Lines changed: 3 additions & 1 deletion

@@ -63,13 +63,15 @@
         'scipy',
         'six>=1.12.0',
         'sympy',
+        'tensorflow-addons',
         'tensorflow-datasets',
         'tensorflow-gan',
         'tensorflow-probability==0.7.0',
+        'tf_slim',
         'tqdm',
     ],
     extras_require={
-        'tensorflow': ['tensorflow>=1.15.0,<2.0'],
+        'tensorflow': ['tensorflow>=1.15.0'],
         'tensorflow-hub': ['tensorflow-hub>=0.1.1'],
         'tests': [
             # Needed to fix a Travis pytest error.

tensor2tensor/bin/t2t-datagen

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ from __future__ import print_function
 
 from tensor2tensor.bin import t2t_datagen
 
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 
 def main(argv):
   t2t_datagen.main(argv)
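
The move to tensorflow.compat.v1 is what keeps these TF1-style entry points
working when a TF2 package is installed. A minimal standalone sketch of the
pattern (not part of this commit):

    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()  # Restore graph-mode semantics under TF2.

    # TF1-style graph code now runs against a TF2 installation.
    x = tf.placeholder(tf.float32, shape=[None])
    y = x * 2.0
    with tf.Session() as sess:
      print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # [2.0, 4.0]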

tensor2tensor/bin/t2t-trainer

Lines changed: 2 additions & 2 deletions

@@ -22,12 +22,12 @@ from __future__ import print_function
 
 from tensor2tensor.bin import t2t_trainer
 
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 
 def main(argv):
   t2t_trainer.main(argv)
 
 
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
-  tf.app.run()
+  tf.app.run(main)
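
Passing main explicitly to tf.app.run avoids depending on the runner
discovering a main function in the __main__ module, which is fragile when the
binary is a thin wrapper around another module. A sketch of the entry-point
pattern in isolation:

    import tensorflow.compat.v1 as tf

    def main(argv):
      del argv  # Unused.
      tf.logging.info("running")

    if __name__ == "__main__":
      tf.logging.set_verbosity(tf.logging.INFO)
      tf.app.run(main)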

tensor2tensor/bin/t2t_trainer.py

Lines changed: 4 additions & 4 deletions

@@ -36,8 +36,6 @@
 from tensor2tensor.utils import usr_dir
 import tensorflow.compat.v1 as tf
 
-from tensorflow.contrib.tpu.python.tpu import tpu_config
-
 
 flags = tf.flags
 FLAGS = flags.FLAGS
@@ -242,8 +240,10 @@ def create_run_config(hp, output_dir=None):
   save_ckpt_steps = None  # Disable the default saver
   save_ckpt_secs = None  # Disable the default saver
   tpu_config_extra_kwargs = {
-      "num_cores_per_replica": 1,
-      "per_host_input_for_training": tpu_config.InputPipelineConfig.BROADCAST,
+      "num_cores_per_replica":
+          1,
+      "per_host_input_for_training":
+          tf.estimator.tpu.InputPipelineConfig.BROADCAST,
   }
 
   # the various custom getters we have written do not play well together yet.
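
tf.estimator.tpu is the non-contrib home of the TPU Estimator config classes,
so InputPipelineConfig.BROADCAST survives the removal of
tensorflow.contrib.tpu. A hedged sketch of how kwargs like these feed a
TPUConfig (the RunConfig values below are placeholders, not this repo's exact
plumbing):

    import tensorflow.compat.v1 as tf

    tpu_config = tf.estimator.tpu.TPUConfig(
        iterations_per_loop=100,
        num_cores_per_replica=1,
        per_host_input_for_training=(
            tf.estimator.tpu.InputPipelineConfig.BROADCAST))
    run_config = tf.estimator.tpu.RunConfig(
        model_dir="/tmp/model", tpu_config=tpu_config)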

tensor2tensor/models/__init__.py

Lines changed: 20 additions & 15 deletions

@@ -30,10 +30,6 @@
 from tensor2tensor.models import image_transformer
 from tensor2tensor.models import image_transformer_2d
 from tensor2tensor.models import lstm
-from tensor2tensor.models import mtf_image_transformer
-from tensor2tensor.models import mtf_resnet
-from tensor2tensor.models import mtf_transformer
-from tensor2tensor.models import mtf_transformer2
 from tensor2tensor.models import neural_assistant
 from tensor2tensor.models import neural_gpu
 from tensor2tensor.models import resnet
@@ -47,15 +43,9 @@
 from tensor2tensor.models.neural_architecture_search import nas_model
 from tensor2tensor.models.research import adafactor_experiments
 from tensor2tensor.models.research import aligned
-from tensor2tensor.models.research import attention_lm
-from tensor2tensor.models.research import attention_lm_moe
 from tensor2tensor.models.research import autoencoders
 from tensor2tensor.models.research import cycle_gan
 from tensor2tensor.models.research import gene_expression
-from tensor2tensor.models.research import glow
-from tensor2tensor.models.research import lm_experiments
-from tensor2tensor.models.research import moe_experiments
-from tensor2tensor.models.research import multiquery_paper
 from tensor2tensor.models.research import neural_stack
 from tensor2tensor.models.research import rl
 from tensor2tensor.models.research import shuffle_network
@@ -69,19 +59,34 @@
 from tensor2tensor.models.research import transformer_symshard
 from tensor2tensor.models.research import transformer_vae
 from tensor2tensor.models.research import universal_transformer
-from tensor2tensor.models.research import vqa_attention
-from tensor2tensor.models.research import vqa_recurrent_self_attention
-from tensor2tensor.models.research import vqa_self_attention
 from tensor2tensor.models.video import basic_deterministic
 from tensor2tensor.models.video import basic_recurrent
 from tensor2tensor.models.video import basic_stochastic
 from tensor2tensor.models.video import emily
-from tensor2tensor.models.video import epva
-from tensor2tensor.models.video import next_frame_glow
 from tensor2tensor.models.video import savp
 from tensor2tensor.models.video import sv2p
+from tensor2tensor.utils import contrib
 from tensor2tensor.utils import registry
 
+# The following models can't be imported under TF2
+if not contrib.is_tf2:
+  # pylint: disable=g-import-not-at-top
+  from tensor2tensor.models.research import attention_lm
+  from tensor2tensor.models.research import attention_lm_moe
+  from tensor2tensor.models.research import glow
+  from tensor2tensor.models.research import lm_experiments
+  from tensor2tensor.models.research import moe_experiments
+  from tensor2tensor.models.research import multiquery_paper
+  from tensor2tensor.models import mtf_image_transformer
+  from tensor2tensor.models import mtf_resnet
+  from tensor2tensor.models import mtf_transformer
+  from tensor2tensor.models import mtf_transformer2
+  from tensor2tensor.models.research import vqa_attention
+  from tensor2tensor.models.research import vqa_recurrent_self_attention
+  from tensor2tensor.models.research import vqa_self_attention
+  from tensor2tensor.models.video import epva
+  from tensor2tensor.models.video import next_frame_glow
+  # pylint: enable=g-import-not-at-top
 
 # pylint: disable=unused-import
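
Because the guarded models are skipped at import time, they never register,
and the registry under TF2 simply lacks them. A quick way to observe the
effect (sketch; registry.list_models is the package's public listing):

    import tensor2tensor.models  # Importing the package triggers registration.
    from tensor2tensor.utils import registry

    # Under TF2 the list omits contrib-dependent models such as "glow".
    print(registry.list_models())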

tensor2tensor/utils/contrib.py

Lines changed: 80 additions & 16 deletions

@@ -23,23 +23,40 @@
 from __future__ import print_function  # Not necessary in a Python 3-only module
 
 from absl import logging
-from tensorflow.python import tf2  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
-is_tf2 = tf2.enabled()
+import tensorflow.compat.v1 as tf
+
+# Check if we have contrib available
+try:
+  from tensorflow.contrib import slim as tf_slim  # pylint: disable=g-import-not-at-top
+  is_tf2 = False
+except:  # pylint: disable=bare-except
+  # tf.contrib, including slim and certain optimizers are not available in TF2
+  # Some features are now available in separate packages. We shim support for
+  # these as needed.
+  import tensorflow_addons as tfa  # pylint: disable=g-import-not-at-top
+  import tf_slim  # pylint: disable=g-import-not-at-top
+  is_tf2 = True
 
 
 def err_if_tf2(msg='err'):
   if is_tf2:
-    msg = 'contrib is unavailable in tf2.'
     if msg == 'err':
+      msg = 'contrib is unavailable in tf2.'
       raise ImportError(msg)
     else:
+      msg = 'contrib is unavailable in tf2.'
      logging.info(msg)
 
 
+class DummyModule(object):
+
+  def __init__(self, **kw):
+    for k, v in kw.items():
+      setattr(self, k, v)
+
+
 def slim():
-  err_if_tf2()
-  from tensorflow.contrib import slim as contrib_slim  # pylint: disable=g-import-not-at-top
-  return contrib_slim
+  return tf_slim
 
 
 def util():
@@ -54,8 +71,26 @@ def tfe():
   return contrib_eager
 
 
+def deprecated(reason, date):
+  del reason
+  del date
+  def decorator(fn):
+    return fn
+  return decorator
+
+
 def framework(msg='err'):
-  err_if_tf2(msg=msg)
+  """Return framework module or dummy version."""
+  del msg
+  if is_tf2:
+    return DummyModule(
+        arg_scope=None,
+        get_name_scope=lambda: tf.get_default_graph().get_name_scope(),
+        name_scope=tf.name_scope,
+        deprecated=deprecated,
+        nest=tf.nest,
+        argsort=tf.argsort)
+
   from tensorflow.contrib import framework as contrib_framework  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
   return contrib_framework
 
@@ -67,9 +102,13 @@ def nn():
 
 
 def layers():
-  err_if_tf2(msg='err')
-  from tensorflow.contrib import layers as contrib_layers  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
-  return contrib_layers
+  """Return layers module or dummy version."""
+  try:
+    from tensorflow.contrib import layers as contrib_layers  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
+    return contrib_layers
+  except:  # pylint: disable=bare-except
+    return DummyModule(
+        OPTIMIZER_CLS_NAMES={}, optimize_loss=tf_slim.optimize_loss)
 
 
 def rnn():
@@ -109,9 +148,13 @@ def metrics():
 
 
 def opt():
-  err_if_tf2(msg='err')
-  from tensorflow.contrib import opt as contrib_opt  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
-  return contrib_opt
+  if not is_tf2:
+    from tensorflow.contrib import opt as contrib_opt  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
+    return contrib_opt
+  return DummyModule(
+      LazyAdam=tfa.optimizers.LazyAdam,
+      LazyAdamOptimizer=tfa.optimizers.LazyAdam,
+  )
 
 
 def mixed_precision():
@@ -132,10 +175,31 @@ def distribute():
   return contrib_distribute
 
 
+def replace_monitors_with_hooks(monitors_or_hooks, estimator):
+  """Stub for missing function."""
+  del estimator
+  monitors_or_hooks = monitors_or_hooks or []
+  hooks = [
+      m for m in monitors_or_hooks if isinstance(m, tf.estimator.SessionRunHook)
+  ]
+  deprecated_monitors = [
+      m for m in monitors_or_hooks
+      if not isinstance(m, tf.estimator.SessionRunHook)
+  ]
+  assert not deprecated_monitors
+  return hooks
+
+
 def learn():
-  err_if_tf2(msg='err')
-  from tensorflow.contrib import learn as contrib_learn  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
-  return contrib_learn
+  """Return tf.contrib.learn module or dummy version."""
+  if not is_tf2:
+    from tensorflow.contrib import learn as contrib_learn  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
+    return contrib_learn
+  return DummyModule(
+      RunConfig=tf.estimator.RunConfig,
+      monitors=DummyModule(
+          replace_monitors_with_hooks=replace_monitors_with_hooks),
+  )
 
 
 def tf_prof():
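
DummyModule gives callers a contrib-shaped namespace whose attributes are
backed by TF2-era equivalents, so call sites like
contrib.opt().LazyAdamOptimizer(...) stay unchanged. A self-contained sketch
of the same pattern, with generic names rather than this file's members:

    class DummyModule(object):
      """Ad-hoc namespace: attributes are supplied as keyword arguments."""

      def __init__(self, **kw):
        for k, v in kw.items():
          setattr(self, k, v)

    # Stand-in for a removed module; attribute-style call sites keep working.
    fake_math = DummyModule(double=lambda x: 2 * x)
    assert fake_math.double(21) == 42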

tensor2tensor/utils/optimize.py

Lines changed: 27 additions & 8 deletions

@@ -112,13 +112,22 @@ def optimize(loss,
 
 @registry.register_optimizer
 def adam(learning_rate, hparams):
+  """Return adam optimizer for the given params."""
   # We change the default epsilon for Adam.
   # Using LazyAdam as it's much faster for large vocabulary embeddings.
-  return contrib.opt().LazyAdamOptimizer(
-      learning_rate,
-      beta1=hparams.optimizer_adam_beta1,
-      beta2=hparams.optimizer_adam_beta2,
-      epsilon=hparams.optimizer_adam_epsilon)
+  if contrib.is_tf2:
+    # in TF2 beta1 -> beta_1 :/
+    return contrib.opt().LazyAdamOptimizer(
+        learning_rate,
+        beta_1=hparams.optimizer_adam_beta1,
+        beta_2=hparams.optimizer_adam_beta2,
+        epsilon=hparams.optimizer_adam_epsilon)
+  else:
+    return contrib.opt().LazyAdamOptimizer(
+        learning_rate,
+        beta1=hparams.optimizer_adam_beta1,
+        beta2=hparams.optimizer_adam_beta2,
+        epsilon=hparams.optimizer_adam_epsilon)
 
 
 @registry.register_optimizer
@@ -229,7 +238,12 @@ def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disab
     self._zero_grads = hparams.optimizer_zero_grads
 
   def compute_gradients(self, loss, var_list=None, **kwargs):  # pylint: disable=arguments-differ
-    gradients = self._opt.compute_gradients(loss, var_list, **kwargs)
+    if contrib.is_tf2:
+      gradients = self._opt.get_gradients(loss, var_list)
+      gradients = zip(gradients, var_list)
+    else:
+      gradients = self._opt.compute_gradients(loss, var_list, **kwargs)
+
     def cast_grad(g, v):
       if v is not None and g is not None:
         g = common_layers.cast_like(g, v)
@@ -240,8 +254,13 @@ def cast_grad(g, v):
     return gradients
 
   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-    return self._opt.apply_gradients(
-        grads_and_vars, global_step=global_step, name=name)
+    if contrib.is_tf2:
+      with tf.control_dependencies(
+          [tf.assign_add(tf.train.get_or_create_global_step(), 1)]):
+        return self._opt.apply_gradients(grads_and_vars, name=name)
+    else:
+      return self._opt.apply_gradients(
+          grads_and_vars, global_step=global_step, name=name)
 
 
 def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
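
Both branches exist because tfa.optimizers.LazyAdam is a Keras OptimizerV2: it
renames beta1/beta2 to beta_1/beta_2, exposes get_gradients instead of
compute_gradients, and no longer touches the global step in apply_gradients
(hence the tf.assign_add above). A sketch of the v2-side calls in isolation,
assuming tensorflow-addons is installed:

    import tensorflow as tf
    import tensorflow_addons as tfa

    opt = tfa.optimizers.LazyAdam(
        learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    var = tf.Variable(1.0)
    with tf.GradientTape() as tape:
      loss = var * var
    grads = tape.gradient(loss, [var])
    opt.apply_gradients(zip(grads, [var]))  # Note: no global_step argument.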

tensor2tensor/utils/trainer_lib.py

Lines changed: 7 additions & 1 deletion

@@ -200,7 +200,7 @@ def create_run_config(model_name,
       "keep_checkpoint_max": keep_checkpoint_max,
       "keep_checkpoint_every_n_hours": keep_checkpoint_every_n_hours,
       "tf_random_seed": random_seed,
-      "log_step_count_steps": log_step_count_steps
+      "log_step_count_steps": log_step_count_steps,
   }
   if save_checkpoints_secs:
     del run_config_args["save_checkpoints_steps"]
@@ -239,6 +239,12 @@
     del run_config_args["master"]
     del run_config_args["evaluation_master"]
 
+  # tf.estimator RunConfig construction got totally broken in TF2.
+  # we now have to specify master in a global environment variable
+  if contrib.is_tf2:
+    del run_config_args["evaluation_master"]
+    del run_config_args["master"]
+
   config = run_config_cls(**run_config_args)
 
   # If not using TPU, add device info for data_parallelism
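
The TF2 Estimator RunConfig constructor no longer accepts
master/evaluation_master; cluster and master information is instead read from
the TF_CONFIG environment variable. A hedged sketch of supplying it that way
(hostnames are placeholders):

    import json
    import os

    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {"chief": ["host0:2222"], "worker": ["host1:2222"]},
        "task": {"type": "chief", "index": 0},
    })
    # tf.estimator.RunConfig() now picks up the cluster spec from the env.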
