
Commit d7dfe79

Author: Błażej O (committed)
Fixing variables scope in loading policy network.
1 parent 2618c54 commit d7dfe79

File tree

1 file changed: +13, -8 lines changed
  • tensor2tensor/data_generators/gym.py

tensor2tensor/data_generators/gym.py

Lines changed: 13 additions & 8 deletions
@@ -24,6 +24,7 @@
 # Dependency imports
 
 import numpy as np
+import functools
 import gym
 
 from tensor2tensor.rl import rl_trainer_lib
@@ -168,21 +169,25 @@ def __init__(self, event_dir, *args, **kwargs):
     self._event_dir = event_dir
     env_spec = lambda: atari_wrappers.wrap_atari(
         gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False)
-    _1, _2, policy_factory = rl_trainer_lib.define_train(rl.atari_base(), env_spec, event_dir=None)
+    hparams = rl.atari_base()
+    with tf.variable_scope("train"):
+      policy_lambda = hparams.network
+      policy_factory = tf.make_template(
+          "network",
+          functools.partial(policy_lambda, env_spec().action_space, hparams))
+      self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape)
+      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
+      policy = actor_critic.policy
+      self._last_policy_op = policy.mode()
     self._last_action = self.env.action_space.sample()
     self._skip = 4
     self._skip_step = 0
     self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, dtype=np.uint8)
     self._sess = tf.Session()
-    model_saver = tf.train.Saver()
+    model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
     model_saver.restore(self._sess, FLAGS.model_path)
 
-    self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape)
-    actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
-    policy = actor_critic.policy
-    self._last_policy_op = policy.mode()
-
-    # TODO(blazej): For training of atari agents wrappers are usually used.
+    # TODO(blazej0): For training of atari agents wrappers are usually used.
     # Below we have a hacky solution which is a temporary workaround to be used together
     # with atari_wrappers.MaxAndSkipEnv.
   def get_action(self, observation=None):
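For context, the pattern in this change is: build the policy template inside the same "train" variable scope used during training so the restored variable names line up with the checkpoint, then limit the Saver to the policy's own variables. Below is a minimal sketch of that pattern, not the actual tensor2tensor code: tiny_policy, the placeholder shape, the six-action output, and the "network_parameters" scope name are illustrative stand-ins, and it assumes TF 1.x graph mode.

# Minimal sketch of the scoping pattern above (assumptions noted in the lead-in).
import functools

import numpy as np
import tensorflow as tf

def tiny_policy(num_actions, hparams, observations):
  # Hypothetical stand-in for hparams.network: flatten the frame, emit action logits.
  with tf.variable_scope("network_parameters"):
    flat = tf.layers.flatten(observations)
    return tf.layers.dense(flat, num_actions)

with tf.variable_scope("train"):
  # make_template reuses one variable set across calls and nests it under
  # train/network/..., matching the names produced at training time.
  policy_factory = tf.make_template(
      "network", functools.partial(tiny_policy, 6, None))
  frame_pl = tf.placeholder(tf.float32, [84, 84, 3])
  logits = policy_factory(tf.expand_dims(frame_pl, 0))

# Restore only the policy weights; the regex keeps optimizer slots and other
# training-only variables out of the Saver, mirroring the diff above.
saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # saver.restore(sess, "/path/to/checkpoint")  # use a real checkpoint path here
  print(sess.run(logits, {frame_pl: np.zeros((84, 84, 3), np.float32)}))

Restoring with a filtered var_list is what makes the load succeed when the checkpoint only contains the policy variables and not the rest of the inference graph.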
