 from __future__ import division
 from __future__ import print_function
 
-from collections import deque
-
 import functools
+import os
+
 # Dependency imports
+
 import gym
 
 from tensor2tensor.data_generators import problem
@@ -62,9 +63,7 @@ def num_target_frames(self):
     return 1
 
   def eval_metrics(self):
-    eval_metrics = [
-        metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
-        metrics.Metrics.NEG_LOG_PERPLEXITY]
+    eval_metrics = [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ]
     return eval_metrics
 
   @property
@@ -108,6 +107,10 @@ def num_rewards(self):
   def num_steps(self):
     raise NotImplementedError()
 
+  @property
+  def total_number_of_frames(self):
+    return self.num_steps
+
   @property
   def min_reward(self):
     raise NotImplementedError()
@@ -126,13 +129,13 @@ def hparams(self, defaults, unused_model_hparams):
     p.target_space_id = problem.SpaceID.IMAGE
 
   def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
-    next_obs = self.env.reset()
+    next_observation = self.env.reset()
     for _ in range(self.num_steps):
-      observation = next_obs
+      observation = next_observation
       action = self.get_action(observation)
-      next_obs, reward, done, _ = self.env.step(action)
+      next_observation, reward, done, _ = self.env.step(action)
       if done:
-        next_obs = self.env.reset()
+        next_observation = self.env.reset()
       yield {"frame": observation,
              "action": [action],
              "done": [done],
@@ -184,23 +187,22 @@ class GymDiscreteProblemWithAgent(GymPongRandom5k):
   def __init__(self, *args, **kwargs):
     super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
     self._env = None
-    self.history_size = 2
+    self.debug_dump_frames_path = "debug_frames_env"
 
     # defaults
     self.environment_spec = lambda: gym.make("PongDeterministic-v4")
-    self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
+    self.in_graph_wrappers = []
     self.collect_hparams = rl.atari_base()
-    self.settable_num_steps = 1000
+    self.settable_num_steps = 20000
     self.simulated_environment = None
-    self.warm_up = 70
+    self.warm_up = 10
 
   @property
   def num_steps(self):
     return self.settable_num_steps
 
   def _setup(self):
-    in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
-                         (atari.MemoryWrapper, {})] + self.in_graph_wrappers
+    in_graph_wrappers = [(atari.MemoryWrapper, {})] + self.in_graph_wrappers
     env_hparams = tf.contrib.training.HParams(
         in_graph_wrappers=in_graph_wrappers,
         simulated_environment=self.simulated_environment)
@@ -229,41 +231,41 @@ def _setup(self):
 
     self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
     self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
-    self.history_buffer = deque(maxlen=self.history_size + 1)
 
   def restore_networks(self, sess):
     if FLAGS.agent_policy_path:
       model_saver = tf.train.Saver(
-        tf.global_variables(".*network_parameters.*"))
+          tf.global_variables(".*network_parameters.*"))
       model_saver.restore(sess, FLAGS.agent_policy_path)
 
   def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
     self._setup()
+    self.debug_dump_frames_path = os.path.join(
+        data_dir, self.debug_dump_frames_path)
 
     with tf.Session() as sess:
       sess.run(tf.global_variables_initializer())
       self.restore_networks(sess)
-
+      # Actions are shifted by 1 by MemoryWrapper, compensate here.
+      avilable_data_size = sess.run(self.avilable_data_size_op)
+      if avilable_data_size < 1:
+        sess.run(self.collect_trigger_op)
       pieces_generated = 0
+      observ, reward, _, _ = sess.run(self.data_get_op)
       while pieces_generated < self.num_steps + self.warm_up:
         avilable_data_size = sess.run(self.avilable_data_size_op)
-        if avilable_data_size > 0:
-          observ, reward, action, _ = sess.run(self.data_get_op)
-          self.history_buffer.append(observ)
-
-          if len(self.history_buffer) == self.history_size + 1:
-            pieces_generated += 1
-            ret_dict = {"image/encoded": [observ],
-                        "image/format": ["png"],
-                        "image/height": [self.frame_height],
-                        "image/width": [self.frame_width],
-                        "action": [int(action)],
-                        "done": [int(False)],
-                        "reward": [int(reward) - self.min_reward]}
-            if pieces_generated > self.warm_up:
-              yield ret_dict
-        else:
+        if avilable_data_size < 1:
           sess.run(self.collect_trigger_op)
+        next_observ, next_reward, action, _ = sess.run(self.data_get_op)
+        yield {"image/encoded": [observ],
+               "image/format": ["png"],
+               "image/height": [self.frame_height],
+               "image/width": [self.frame_width],
+               "action": [int(action)],
+               "done": [int(False)],
+               "reward": [int(reward) - self.min_reward]}
+        pieces_generated += 1
+        observ, reward = next_observ, next_reward
 
 
 @registry.register_problem
@@ -273,7 +275,7 @@ class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
   def __init__(self, *args, **kwargs):
     super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
     self.simulated_environment = True
-    self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames"
+    self.debug_dump_frames_path = "debug_frames_sim"
 
   def restore_networks(self, sess):
     super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)
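
Note on the rewritten generate_encoded_samples loop: per the in-line comment, MemoryWrapper stores each action one step late, so the new generator keeps the previous (observ, reward) pair and combines it with the action returned by the following dequeue. Below is a minimal, self-contained sketch of that pairing pattern; make_dequeue and paired_samples are hypothetical stand-ins for illustration, not part of this commit.

# Sketch of the one-step-lag pairing used above (hypothetical dequeue()
# stand-in, not the real MemoryWrapper speculum queue).
from collections import deque


def make_dequeue(transitions):
  """Return a callable yielding (observ, reward, action, done) tuples."""
  buf = deque(transitions)
  return buf.popleft


def paired_samples(dequeue, num_steps):
  # The action stored with each dequeued frame refers to the previous frame,
  # so keep the last (observ, reward) and pair it with the next action.
  observ, reward, _, _ = dequeue()
  for _ in range(num_steps):
    next_observ, next_reward, action, _ = dequeue()
    yield {"frame": observ, "action": action, "reward": reward}
    observ, reward = next_observ, next_reward


if __name__ == "__main__":
  fake = [("obs%d" % i, i % 2, "act%d" % i, False) for i in range(5)]
  for sample in paired_samples(make_dequeue(fake), num_steps=3):
    print(sample)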