2323from __future__ import print_function # Not necessary in a Python 3-only module
2424
2525from absl import logging
26- from tensorflow .python import tf2 # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
27- is_tf2 = tf2 .enabled ()
26+ import tensorflow .compat .v1 as tf
27+
28+ # Check if we have contrib available
29+ try :
30+ from tensorflow .contrib import slim as tf_slim # pylint: disable=g-import-not-at-top
31+ is_tf2 = False
32+ except : # pylint: disable=bare-except
33+ # tf.contrib, including slim and certain optimizers are not available in TF2
34+ # Some features are now available in separate packages. We shim support for
35+ # these as needed.
36+ import tensorflow_addons as tfa # pylint: disable=g-import-not-at-top
37+ import tf_slim # pylint: disable=g-import-not-at-top
38+ is_tf2 = True
2839
2940
3041def err_if_tf2 (msg = 'err' ):
3142 if is_tf2 :
32- msg = 'contrib is unavailable in tf2.'
3343 if msg == 'err' :
44+ msg = 'contrib is unavailable in tf2.'
3445 raise ImportError (msg )
3546 else :
47+ msg = 'contrib is unavailable in tf2.'
3648 logging .info (msg )
3749
3850
51+ class DummyModule (object ):
52+
53+ def __init__ (self , ** kw ):
54+ for k , v in kw .items ():
55+ setattr (self , k , v )
56+
57+
3958def slim ():
40- err_if_tf2 ()
41- from tensorflow .contrib import slim as contrib_slim # pylint: disable=g-import-not-at-top
42- return contrib_slim
59+ return tf_slim
4360
4461
4562def util ():
@@ -54,8 +71,26 @@ def tfe():
5471 return contrib_eager
5572
5673
74+ def deprecated (reason , date ):
75+ del reason
76+ del date
77+ def decorator (fn ):
78+ return fn
79+ return decorator
80+
81+
5782def framework (msg = 'err' ):
58- err_if_tf2 (msg = msg )
83+ """Return framework module or dummy version."""
84+ del msg
85+ if is_tf2 :
86+ return DummyModule (
87+ arg_scope = None ,
88+ get_name_scope = lambda : tf .get_default_graph ().get_name_scope (),
89+ name_scope = tf .name_scope ,
90+ deprecated = deprecated ,
91+ nest = tf .nest ,
92+ argsort = tf .argsort )
93+
5994 from tensorflow .contrib import framework as contrib_framework # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
6095 return contrib_framework
6196
@@ -67,9 +102,13 @@ def nn():
67102
68103
69104def layers ():
70- err_if_tf2 (msg = 'err' )
71- from tensorflow .contrib import layers as contrib_layers # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
72- return contrib_layers
105+ """Return layers module or dummy version."""
106+ try :
107+ from tensorflow .contrib import layers as contrib_layers # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
108+ return contrib_layers
109+ except : # pylint: disable=bare-except
110+ return DummyModule (
111+ OPTIMIZER_CLS_NAMES = {}, optimize_loss = tf_slim .optimize_loss )
73112
74113
75114def rnn ():
@@ -109,9 +148,13 @@ def metrics():
109148
110149
111150def opt ():
112- err_if_tf2 (msg = 'err' )
113- from tensorflow .contrib import opt as contrib_opt # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
114- return contrib_opt
151+ if not is_tf2 :
152+ from tensorflow .contrib import opt as contrib_opt # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
153+ return contrib_opt
154+ return DummyModule (
155+ LazyAdam = tfa .optimizers .LazyAdam ,
156+ LazyAdamOptimizer = tfa .optimizers .LazyAdam ,
157+ )
115158
116159
117160def mixed_precision ():
@@ -132,10 +175,31 @@ def distribute():
132175 return contrib_distribute
133176
134177
178+ def replace_monitors_with_hooks (monitors_or_hooks , estimator ):
179+ """Stub for missing function."""
180+ del estimator
181+ monitors_or_hooks = monitors_or_hooks or []
182+ hooks = [
183+ m for m in monitors_or_hooks if isinstance (m , tf .estimator .SessionRunHook )
184+ ]
185+ deprecated_monitors = [
186+ m for m in monitors_or_hooks
187+ if not isinstance (m , tf .estimator .SessionRunHook )
188+ ]
189+ assert not deprecated_monitors
190+ return hooks
191+
192+
135193def learn ():
136- err_if_tf2 (msg = 'err' )
137- from tensorflow .contrib import learn as contrib_learn # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
138- return contrib_learn
194+ """Return tf.contrib.learn module or dummy version."""
195+ if not is_tf2 :
196+ from tensorflow .contrib import learn as contrib_learn # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
197+ return contrib_learn
198+ return DummyModule (
199+ RunConfig = tf .estimator .RunConfig ,
200+ monitors = DummyModule (
201+ replace_monitors_with_hooks = replace_monitors_with_hooks ),
202+ )
139203
140204
141205def tf_prof ():
0 commit comments