This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 4ebb860

Merge pull request #770 from deepsense-ai/more_games

Add more Atari games.

2 parents 6cea4c4 + 032b595

File tree

4 files changed: 80 additions, 49 deletions

tensor2tensor/data_generators/gym.py
tensor2tensor/rl/envs/simulated_batch_env.py
tensor2tensor/rl/envs/utils.py
tensor2tensor/rl/model_rl_experiment.py


tensor2tensor/data_generators/gym.py

60 additions, 11 deletions
@@ -42,7 +42,7 @@
 flags = tf.flags
 FLAGS = flags.FLAGS

-flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
+flags.DEFINE_string("agent_policy_path", "", "File with model for agent")


 class GymDiscreteProblem(video_utils.VideoProblem):

@@ -99,6 +99,14 @@ def env(self):
   def num_actions(self):
     return self.env.action_space.n

+  @property
+  def frame_height(self):
+    return self.env.observation_space.shape[0]
+
+  @property
+  def frame_width(self):
+    return self.env.observation_space.shape[1]
+
   @property
   def num_rewards(self):
     raise NotImplementedError()

@@ -150,14 +158,6 @@ class GymPongRandom5k(GymDiscreteProblem):
   def env_name(self):
     return "PongDeterministic-v4"

-  @property
-  def frame_height(self):
-    return 210
-
-  @property
-  def frame_width(self):
-    return 160
-
   @property
   def min_reward(self):
     return -1

@@ -179,9 +179,38 @@ class GymPongRandom50k(GymPongRandom5k):
   def num_steps(self):
     return 50000

+@registry.register_problem
+class GymFreewayRandom5k(GymDiscreteProblem):
+  """Freeway game, random actions."""
+
+  @property
+  def env_name(self):
+    return "FreewayDeterministic-v4"
+
+  @property
+  def min_reward(self):
+    return 0
+
+  @property
+  def num_rewards(self):
+    return 2
+
+  @property
+  def num_steps(self):
+    return 5000
+
+
+@registry.register_problem
+class GymFreewayRandom50k(GymFreewayRandom5k):
+  """Freeway game, random actions."""
+
+  @property
+  def num_steps(self):
+    return 50000
+

 @registry.register_problem
-class GymDiscreteProblemWithAgent(GymPongRandom5k):
+class GymDiscreteProblemWithAgent(GymDiscreteProblem):
   """Gym environment with discrete actions and rewards and an agent."""

   def __init__(self, *args, **kwargs):

@@ -190,7 +219,7 @@ def __init__(self, *args, **kwargs):
     self.debug_dump_frames_path = "debug_frames_env"

     # defaults
-    self.environment_spec = lambda: gym.make("PongDeterministic-v4")
+    self.environment_spec = lambda: gym.make(self.env_name)
     self.in_graph_wrappers = []
     self.collect_hparams = rl.atari_base()
     self.settable_num_steps = 20000

@@ -286,3 +315,23 @@ def restore_networks(self, sess):
     ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
     ckpt = ckpts.model_checkpoint_path
     env_model_loader.restore(sess, ckpt)
+
+
+@registry.register_problem
+class GymSimulatedDiscreteProblemWithAgentOnPong(GymSimulatedDiscreteProblemWithAgent, GymPongRandom5k):
+  pass
+
+
+@registry.register_problem
+class GymDiscreteProblemWithAgentOnPong(GymDiscreteProblemWithAgent, GymPongRandom5k):
+  pass
+
+
+@registry.register_problem
+class GymSimulatedDiscreteProblemWithAgentOnFreeway(GymSimulatedDiscreteProblemWithAgent, GymFreewayRandom5k):
+  pass
+
+
+@registry.register_problem
+class GymDiscreteProblemWithAgentOnFreeway(GymDiscreteProblemWithAgent, GymFreewayRandom5k):
+  pass
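
The On-Pong and On-Freeway classes above are plain mixins: the agent and simulator machinery comes from GymDiscreteProblemWithAgent or GymSimulatedDiscreteProblemWithAgent, while the game-specific properties (env_name, reward bounds, step counts) come from the random-action base class. As a rough sketch of the pattern, a further game could be registered the same way; Boxing and its reward values below are hypothetical illustrations, not part of this commit.

@registry.register_problem
class GymBoxingRandom5k(GymDiscreteProblem):
  """Boxing game, random actions (hypothetical example)."""

  @property
  def env_name(self):
    return "BoxingDeterministic-v4"

  @property
  def min_reward(self):
    return -1  # illustrative; per-step reward bounds depend on the game

  @property
  def num_rewards(self):
    return 3  # illustrative: e.g. rewards in {-1, 0, 1}

  @property
  def num_steps(self):
    return 5000


@registry.register_problem
class GymDiscreteProblemWithAgentOnBoxing(GymDiscreteProblemWithAgent,
                                          GymBoxingRandom5k):
  pass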

tensor2tensor/rl/envs/simulated_batch_env.py

14 additions, 27 deletions
@@ -32,15 +32,13 @@
 from tensor2tensor.utils import trainer_lib

 import tensorflow as tf
+import numpy as np


 flags = tf.flags
 FLAGS = flags.FLAGS


-flags.DEFINE_string("frames_path", "", "Path to the first frames.")
-
-
 class SimulatedBatchEnv(InGraphBatchEnv):
   """Batch of environments inside the TensorFlow graph.

@@ -49,42 +47,31 @@ class SimulatedBatchEnv(InGraphBatchEnv):
   flags are held in according variables.
   """

-  def __init__(self, length, observ_shape, observ_dtype, action_shape,
-               action_dtype):
+  def __init__(self, environment_lambda, length):
     """Batch of environments inside the TensorFlow graph."""
     self.length = length
+    initalization_env = environment_lambda()
     hparams = trainer_lib.create_hparams(
         FLAGS.hparams_set, problem_name=FLAGS.problem, data_dir="UNUSED")
     hparams.force_full_predict = True
     self._model = registry.model(FLAGS.model)(
         hparams, tf.estimator.ModeKeys.PREDICT)

-    self.action_shape = action_shape
-    self.action_dtype = action_dtype
-
-    with open(os.path.join(FLAGS.frames_path, "frame1.png"), "rb") as f:
-      png_frame_1_raw = f.read()
+    self.action_space = initalization_env.action_space
+    self.action_shape = list(initalization_env.action_space.shape)
+    self.action_dtype = tf.int32

-    with open(os.path.join(FLAGS.frames_path, "frame2.png"), "rb") as f:
-      png_frame_2_raw = f.read()
+    obs_1 = initalization_env.reset()
+    obs_2 = initalization_env.step(0)[0]

-    self.frame_1 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_1_raw),
-                                          tf.float32), 0)
-    self.frame_2 = tf.expand_dims(tf.cast(tf.image.decode_png(png_frame_2_raw),
-                                          tf.float32), 0)
+    self.frame_1 = tf.expand_dims(tf.cast(obs_1, tf.float32), 0)
+    self.frame_2 = tf.expand_dims(tf.cast(obs_2, tf.float32), 0)

-    shape = (self.length,) + observ_shape
-    self._observ = tf.Variable(tf.zeros(shape, observ_dtype), trainable=False)
-    self._prev_observ = tf.Variable(tf.zeros(shape, observ_dtype),
+    shape = (self.length,) + initalization_env.observation_space.shape
+    # TODO(blazej0) - make more generic - make higher number of previous observations possible.
+    self._observ = tf.Variable(tf.zeros(shape, tf.float32), trainable=False)
+    self._prev_observ = tf.Variable(tf.zeros(shape, tf.float32),
                                     trainable=False)
-    self._starting_observ = tf.Variable(tf.zeros(shape, observ_dtype),
-                                        trainable=False)
-
-    observ_dtype = tf.int64
-
-  @property
-  def action_space(self):
-    return gym.make("PongNoFrameskip-v4").action_space

   def __len__(self):
     """Number of combined environments."""

tensor2tensor/rl/envs/utils.py

3 additions, 9 deletions
@@ -287,7 +287,7 @@ def batch_env_factory(environment_lambda, hparams, num_agents, xvfb=False):
       hparams, "in_graph_wrappers") else []

   if hparams.simulated_environment:
-    cur_batch_env = define_simulated_batch_env(num_agents)
+    cur_batch_env = define_simulated_batch_env(environment_lambda, num_agents)
   else:
     cur_batch_env = define_batch_env(environment_lambda, num_agents, xvfb=xvfb)
   for w in wrappers:

@@ -306,12 +306,6 @@ def define_batch_env(constructor, num_agents, xvfb=False):
   return env


-def define_simulated_batch_env(num_agents):
-  # TODO(blazej0): the parameters should be infered.
-  observ_shape = (210, 160, 3)
-  observ_dtype = tf.float32
-  action_shape = []
-  action_dtype = tf.int32
-  cur_batch_env = simulated_batch_env.SimulatedBatchEnv(
-      num_agents, observ_shape, observ_dtype, action_shape, action_dtype)
+def define_simulated_batch_env(environment_lambda, num_agents):
+  cur_batch_env = simulated_batch_env.SimulatedBatchEnv(environment_lambda, num_agents)
   return cur_batch_env
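
This resolves the old TODO: instead of parameters hard-coded for Pong, everything is now inferred inside SimulatedBatchEnv, so the call site only threads the environment constructor through. A hedged usage sketch, with env_lambda standing in for whatever environment_spec the problem supplies:

import gym

# env_lambda here is a hypothetical stand-in; in this codebase it arrives
# via batch_env_factory from the problem's environment_spec.
env_lambda = lambda: gym.make("FreewayDeterministic-v4")
batch_env = define_simulated_batch_env(env_lambda, num_agents=2)
# num_agents becomes SimulatedBatchEnv.length, the batch dimension.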

tensor2tensor/rl/model_rl_experiment.py

3 additions, 2 deletions
@@ -51,7 +51,7 @@ def train(hparams, output_dir):
     time_delta = time.time() - start_time
     print(line+"Step {}.1. - generate data from policy. "
           "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
-    FLAGS.problem = "gym_discrete_problem_with_agent"
+    FLAGS.problem = "gym_discrete_problem_with_agent_on_{}".format(hparams.game)
     FLAGS.agent_policy_path = last_model
     gym_problem = registry.problem(FLAGS.problem)
     gym_problem.settable_num_steps = hparams.true_env_generator_num_steps

@@ -76,7 +76,7 @@ def train(hparams, output_dir):
     print(line+"Step {}.3. - evalue env model. "
           "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
     gym_simulated_problem = registry.problem(
-        "gym_simulated_discrete_problem_with_agent")
+        "gym_simulated_discrete_problem_with_agent_on_{}".format(hparams.game))
     sim_steps = hparams.simulated_env_generator_num_steps
     gym_simulated_problem.settable_num_steps = sim_steps
     gym_simulated_problem.generate_data(iter_data_dir, tmp_dir)

@@ -115,6 +115,7 @@ def main(_):
       simulated_env_generator_num_steps=300,
       ppo_epochs_num=200,
       ppo_epoch_length=300,
+      game="pong",
   )
   train(hparams, FLAGS.output_dir)
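
Because the problem names are now formatted from hparams.game, pointing the whole training loop at another game is a one-field change. A small sketch of how the names resolve; tensor2tensor registry keys are the snake_case forms of the registered class names:

# "freeway" is the other game registered by this commit; "pong" is the default.
hparams.game = "freeway"
problem_name = "gym_discrete_problem_with_agent_on_{}".format(hparams.game)
# -> "gym_discrete_problem_with_agent_on_freeway", the registry key derived
#    from the class name GymDiscreteProblemWithAgentOnFreeway.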

0 commit comments