Skip to content

Commit f0c4c92

Browse files
hertschuhcopybara-github
authored andcommitted
Explicitly import estimator from tensorflow as a separate import instead of
accessing it via tf.estimator and depend on the tensorflow estimator target. PiperOrigin-RevId: 439395101
1 parent 6fc7fd5 commit f0c4c92

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

tensorflow_graphics/projects/nasa/lib/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Model Implementations."""
1515
import tensorflow.compat.v1 as tf
1616

17+
from tensorflow.compat.v1 import estimator as tf_estimator
1718
from tensorflow_graphics.projects.nasa.lib import model_utils
1819

1920
tf.disable_eager_execution()
@@ -37,7 +38,7 @@ def nasa(hparams):
3738
sample_bbox = hparams.sample_bbox
3839

3940
def _model_fn(features, labels, mode, params=None):
40-
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
41+
is_training = (mode == tf_estimator.ModeKeys.TRAIN)
4142
batch_size = features['point'].shape[0]
4243
n_sample_frames = features['point'].shape[1]
4344
accum_size = batch_size * n_sample_frames
@@ -116,7 +117,7 @@ def _model_fn(features, labels, mode, params=None):
116117
train_op = optimizer.minimize(
117118
indicator_loss, global_step=global_step, name='optimizer_shape')
118119

119-
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
120+
return tf_estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
120121

121122
return _model_fn
122123

tensorflow_graphics/projects/nasa/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
import tensorflow.compat.v1 as tf
1717

18+
from tensorflow.compat.v1 import estimator as tf_estimator
1819
from tensorflow_graphics.projects.nasa.lib import datasets
1920
from tensorflow_graphics.projects.nasa.lib import models
2021
from tensorflow_graphics.projects.nasa.lib import utils
@@ -45,13 +46,13 @@ def main(unused_argv):
4546

4647
# Set up training.
4748
logging.info("=> Setting up training ...")
48-
run_config = tf.estimator.RunConfig(
49+
run_config = tf_estimator.RunConfig(
4950
model_dir=FLAGS.train_dir,
5051
save_checkpoints_steps=FLAGS.save_every,
5152
save_summary_steps=FLAGS.summary_every,
5253
keep_checkpoint_max=None,
5354
)
54-
trainer = tf.estimator.Estimator(
55+
trainer = tf_estimator.Estimator(
5556
model_fn=model_fn,
5657
config=run_config,
5758
)

0 commit comments

Comments
 (0)