@@ -24,6 +24,7 @@
 # Dependency imports

 import numpy as np
+import functools
 import gym

 from tensor2tensor.rl import rl_trainer_lib
@@ -168,21 +169,25 @@ def __init__(self, event_dir, *args, **kwargs): |
     self._event_dir = event_dir
     env_spec = lambda: atari_wrappers.wrap_atari(
         gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False)
-    _1, _2, policy_factory = rl_trainer_lib.define_train(rl.atari_base(), env_spec, event_dir=None)
+    hparams = rl.atari_base()
+    with tf.variable_scope("train"):
+      policy_lambda = hparams.network
+      policy_factory = tf.make_template(
+          "network",
+          functools.partial(policy_lambda, env_spec().action_space, hparams))
+      self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape)
+      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
+      policy = actor_critic.policy
+      self._last_policy_op = policy.mode()
     self._last_action = self.env.action_space.sample()
     self._skip = 4
     self._skip_step = 0
     self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, dtype=np.uint8)
     self._sess = tf.Session()
-    model_saver = tf.train.Saver()
+    model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
     model_saver.restore(self._sess, FLAGS.model_path)

-    self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape)
-    actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
-    policy = actor_critic.policy
-    self._last_policy_op = policy.mode()
-
-    # TODO(blazej): For training of atari agents wrappers are usually used.
+    # TODO(blazej0): For training of atari agents wrappers are usually used.
     # Below we have a hacky solution which is a temporary workaround to be used together
     # with atari_wrappers.MaxAndSkipEnv.
   def get_action(self, observation=None):
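Note (editorial, not part of the commit): the replacement code builds the policy directly with tf.make_template instead of going through rl_trainer_lib.define_train. tf.make_template wraps a variable-creating function so that every call shares one set of variables, and tf.global_variables(pattern) then lets the Saver restore only the variables whose names match that pattern, as done here with ".*network_parameters.*". A minimal, self-contained TF 1.x sketch of the same pattern follows; the _linear helper is hypothetical and only stands in for hparams.network.

import tensorflow as tf

def _linear(x):
  # Hypothetical stand-in for hparams.network: creates a single weight variable.
  w = tf.get_variable("w", [x.shape.as_list()[-1], 1])
  return tf.matmul(x, w)

# Every call to the template reuses the variables created on the first call.
shared_policy = tf.make_template("network", _linear)
out_a = shared_policy(tf.zeros([2, 3]))  # first call creates network/w
out_b = shared_policy(tf.ones([2, 3]))   # second call reuses network/w

# Build a Saver over only the variables whose names match a pattern,
# mirroring tf.train.Saver(tf.global_variables(".*network_parameters.*")) above.
saver = tf.train.Saver(tf.global_variables(".*network.*"))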
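The hunk ends before the body of get_action, so for orientation only, here is a hedged sketch of how the fields set up in __init__ could mimic atari_wrappers.MaxAndSkipEnv at play time, as the TODO above suggests; the committed implementation may differ.

  def get_action(self, observation=None):
    # Keep the two most recent raw frames so flickering Atari sprites survive
    # the max-pooling step, as MaxAndSkipEnv does during training.
    if observation is not None:
      self._obs_buffer[self._skip_step % 2] = observation
    self._skip_step = (self._skip_step + 1) % self._skip
    # Query the restored policy only once every self._skip frames and repeat
    # the last action in between.
    if self._skip_step == 0:
      max_frame = self._obs_buffer.max(axis=0)
      self._last_action = self._sess.run(
          self._last_policy_op,
          feed_dict={self._max_frame_pl: max_frame})[0, 0]
    return self._last_action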