 # Dependency imports
 
 import gym
-import numpy as np
+import os
+from tensorflow.contrib.training import HParams
+from collections import deque
 
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
-
 from tensor2tensor.models.research import rl
-from tensor2tensor.rl import rl_trainer_lib  # pylint: disable=unused-import
-from tensor2tensor.rl.envs import atari_wrappers
-
-from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry
+from tensor2tensor.rl.envs.utils import batch_env_factory
+from tensor2tensor.rl.envs.tf_atari_wrappers import MemoryWrapper, TimeLimitWrapper
+from tensor2tensor.rl.envs.tf_atari_wrappers import MaxAndSkipWrapper
+from tensor2tensor.rl.envs.tf_atari_wrappers import PongT2TGeneratorHackWrapper
+from tensor2tensor.rl import collect
 
 import tensorflow as tf
 
 
+def moviepy_editor():
+  """Access to moviepy to allow for import of this file without a moviepy install."""
+  try:
+    from moviepy import editor  # pylint: disable=g-import-not-at-top
+  except ImportError:
+    raise ImportError("pip install moviepy to record videos")
+  return editor
+
 flags = tf.flags
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string("model_path", "", "File with model for pong")
-
+flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
 
+@registry.register_problem
 class GymDiscreteProblem(problem.Problem):
   """Gym environment with discrete actions and rewards."""
 
   def __init__(self, *args, **kwargs):
     super(GymDiscreteProblem, self).__init__(*args, **kwargs)
-    self._env = None
+    self.num_channels = 3
+    self.history_size = 2
+
+    # Defaults.
+    self.environment_spec = lambda: gym.make("PongNoFrameskip-v4")
+    self.in_graph_wrappers = [(MaxAndSkipWrapper, {"skip": 4})]
+    self.collect_hparams = rl.atari_base()
+    self.num_steps = 1000
+    self.movies = True
+    self.movies_fps = 24
+    self.simulated_environment = None
+    self.warm_up = 70
+
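+  # _setup() builds the in-graph data collection pipeline used by generator():
+  # it wraps the environment, instantiates the agent policy and exposes the
+  # MemoryWrapper queue ops that generator() polls for new frames.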
+  def _setup(self):
+    # TODO: remove PongT2TGeneratorHackWrapper by writing a modality.
+
+    in_graph_wrappers = [(PongT2TGeneratorHackWrapper, {"add_value": 2}),
+                         (MemoryWrapper, {})] + self.in_graph_wrappers
+    env_hparams = HParams(in_graph_wrappers=in_graph_wrappers,
+                          simulated_environment=self.simulated_environment)
+
+    generator_batch_env = batch_env_factory(
+        self.environment_spec, env_hparams, num_agents=1, xvfb=False)
+
+    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+      policy_lambda = self.collect_hparams.network
+      policy_factory = tf.make_template(
+          "network",
+          functools.partial(policy_lambda,
+                            self.environment_spec().action_space,
+                            self.collect_hparams),
+          create_scope_now_=True,
+          unique_name_="network")
 
-  def example_reading_spec(self, label_repr=None):
+    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+      sample_policy = lambda policy: 0 * policy.sample()
 
+    self.collect_hparams.epoch_length = 10
+    _, self.collect_trigger_op = collect.define_collect(
+        policy_factory, generator_batch_env, self.collect_hparams,
+        eval_phase=False, policy_to_actions_lambda=sample_policy,
+        scope="define_collect")
+
+    self.available_data_size_op = MemoryWrapper.singleton._speculum.size()
+    self.data_get_op = MemoryWrapper.singleton._speculum.dequeue()
+    self.history_buffer = deque(maxlen=self.history_size + 1)
+
+  def example_reading_spec(self, label_repr=None):
     data_fields = {
-        "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
-        "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
-        "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
-        "action": tf.FixedLenFeature([1], tf.int64),
-        "reward": tf.FixedLenFeature([1], tf.int64)
+        "targets_encoded": tf.FixedLenFeature((), tf.string),
+        "image/format": tf.FixedLenFeature((), tf.string),
+        "action": tf.FixedLenFeature([1], tf.int64),
+        "reward": tf.FixedLenFeature([1], tf.int64),
+        # "done": tf.FixedLenFeature([1], tf.int64)
     }
 
-    return data_fields, None
+    for x in range(self.history_size):
+      data_fields["inputs_encoded_{}".format(x)] = tf.FixedLenFeature((), tf.string)
 
-  def eval_metrics(self):
-    return [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
-            metrics.Metrics.NEG_LOG_PERPLEXITY, metrics.Metrics.IMAGE_SUMMARY]
 
-  @property
-  def env_name(self):
-    # This is the name of the Gym environment for this problem.
-    raise NotImplementedError()
+    data_items_to_decoders = {
+        "targets":
+            tf.contrib.slim.tfexample_decoder.Image(
+                image_key="targets_encoded",
+                format_key="image/format",
+                shape=[210, 160, 3],
+                channels=3),
 
-  @property
-  def env(self):
-    if self._env is None:
-      self._env = gym.make(self.env_name)
-    return self._env
+        # Just pass the action and reward through.
+        "action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
+        "reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
+    }
 
-  @property
-  def num_channels(self):
-    return 3
+    for x in range(self.history_size):
+      data_items_to_decoders["inputs_{}".format(x)] = (
+          tf.contrib.slim.tfexample_decoder.Image(
+              image_key="inputs_encoded_{}".format(x),
+              format_key="image/format",
+              shape=[210, 160, 3],
+              channels=3))
+
+    return data_fields, data_items_to_decoders
+
+  # def preprocess_example(self, example, mode, hparams):
+  #   if not self._was_reversed:
+  #     for x in range(self.history_size):
+  #       input_name = "inputs_{}".format(x)
+  #       example[input_name] = tf.image.per_image_standardization(example[input_name])
+  #   return example
 
   @property
   def num_actions(self):
-    raise NotImplementedError()
+    return 4
 
   @property
   def num_rewards(self):
-    raise NotImplementedError()
-
-  @property
-  def num_steps(self):
-    raise NotImplementedError()
+    return 2
 
   @property
   def num_shards(self):
@@ -108,35 +168,70 @@ def get_action(self, observation=None):
 
   def hparams(self, defaults, unused_model_hparams):
     p = defaults
-    p.input_modality = {"inputs": ("image", 256),
-                        "inputs_prev": ("image", 256),
-                        "reward": ("symbol", self.num_rewards),
-                        "action": ("symbol", self.num_actions)}
-    p.target_modality = ("image", 256)
+    # The hard-coded +1 after "symbol" accounts for 0 being a special padding
+    # symbol, so e.g. rewards 0, 1 are shifted to 1, 2.
+    p.input_modality = {"action": ("symbol:identity", self.num_actions)}
+
+    for x in range(self.history_size):
+      p.input_modality["inputs_{}".format(x)] = ("image", 256)
+
+    p.target_modality = {"targets": ("image", 256),
+                         "reward": ("symbol", self.num_rewards + 1),
+                         # "done": ("symbol", 2 + 1)
+                        }
+
     p.input_space_id = problem.SpaceID.IMAGE
     p.target_space_id = problem.SpaceID.IMAGE
 
+  def restore_networks(self, sess):
+    model_saver = tf.train.Saver(
+        tf.global_variables(".*network_parameters.*"))
+    if FLAGS.agent_policy_path:
+      model_saver.restore(sess, FLAGS.agent_policy_path)
+
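+  # generator() runs the in-graph collect op until enough frames have been
+  # produced, dequeues (observation, reward, action, done) tuples from the
+  # MemoryWrapper queue and, once the history buffer is full, yields
+  # tf.Example-ready dicts (optionally also dumping frames for a video).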
   def generator(self, data_dir, tmp_dir):
-    self.env.reset()
-    action = self.get_action()
-    prev_observation, observation = None, None
-    for _ in range(self.num_steps):
-      prev_prev_observation = prev_observation
-      prev_observation = observation
-      observation, reward, done, _ = self.env.step(action)
-      action = self.get_action(observation)
-      if done:
-        self.env.reset()
-      def flatten(nparray):
-        flat1 = [x for sublist in nparray.tolist() for x in sublist]
-        return [x for sublist in flat1 for x in sublist]
-      if prev_prev_observation is not None:
-        yield {"inputs_prev": flatten(prev_prev_observation),
-               "inputs": flatten(prev_observation),
-               "action": [action],
-               "done": [done],
-               "reward": [int(reward)],
-               "targets": flatten(observation)}
+    self._setup()
+    clip_files = []
+    with tf.Session() as sess:
+      sess.run(tf.global_variables_initializer())
+      self.restore_networks(sess)
+
+      pieces_generated = 0
+      while pieces_generated < self.num_steps + self.warm_up:
+        available_data_size = sess.run(self.available_data_size_op)
+        if available_data_size > 0:
+          observ, reward, action, done = sess.run(self.data_get_op)
+          self.history_buffer.append(observ)
+
+          if self.movies and pieces_generated > self.warm_up:
+            file_name = os.path.join(tmp_dir,
+                                     "output_{}.png".format(pieces_generated))
+            clip_files.append(file_name)
+            with open(file_name, "wb") as f:
+              f.write(observ)
+
+          if len(self.history_buffer) == self.history_size + 1:
+            pieces_generated += 1
+            ret_dict = {
+                "targets_encoded": [observ],
+                "image/format": ["png"],
+                "action": [int(action)],
+                # "done": [bool(done)],
+                "reward": [int(reward)],
+            }
+            for i, v in enumerate(list(self.history_buffer)[:-1]):
+              ret_dict["inputs_encoded_{}".format(i)] = [v]
+            if pieces_generated > self.warm_up:
+              yield ret_dict
+        else:
+          sess.run(self.collect_trigger_op)
+      if self.movies:
+        clip = moviepy_editor().ImageSequenceClip(clip_files,
+                                                  fps=self.movies_fps)
+        clip.write_videofile(os.path.join(data_dir,
+                                          "output_{}.mp4".format(self.name)),
+                             fps=self.movies_fps, codec="mpeg4")
+
 
   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     train_paths = self.training_filepaths(
@@ -150,93 +245,23 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
 
 
 @registry.register_problem
-class GymPongRandom5k(GymDiscreteProblem):
-  """Pong game, random actions."""
-
-  @property
-  def env_name(self):
-    return "PongDeterministic-v4"
-
-  @property
-  def num_actions(self):
-    return 4
-
-  @property
-  def num_rewards(self):
-    return 2
-
-  @property
-  def num_steps(self):
-    return 5000
-
-
-@registry.register_problem
-class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
-  """Pong game, loaded actions."""
+class GymSimulatedDiscreteProblem(GymDiscreteProblem):
  """Simulated gym environment with discrete actions and rewards."""
 
   def __init__(self, *args, **kwargs):
-    super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
-    self._env = None
-    self._last_policy_op = None
-    self._max_frame_pl = None
-    self._last_action = self.env.action_space.sample()
-    self._skip = 4
-    self._skip_step = 0
-    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
-                                dtype=np.uint8)
-
-  def generator(self, data_dir, tmp_dir):
-    env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
-        gym.make(self.env_name),
-        warp=False,
-        frame_skip=4,
-        frame_stack=False)
-    hparams = rl.atari_base()
-    with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
-      policy_lambda = hparams.network
-      policy_factory = tf.make_template(
-          "network",
-          functools.partial(policy_lambda, env_spec().action_space, hparams))
-      self._max_frame_pl = tf.placeholder(
-          tf.float32, self.env.observation_space.shape)
-      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
-          self._max_frame_pl, 0), 0))
-      policy = actor_critic.policy
-      self._last_policy_op = policy.mode()
-    with tf.Session() as sess:
-      model_saver = tf.train.Saver(
-          tf.global_variables(".*network_parameters.*"))
-      model_saver.restore(sess, FLAGS.model_path)
-      for item in super(GymPongTrajectoriesFromPolicy,
-                        self).generator(data_dir, tmp_dir):
-        yield item
-
-  # TODO(blazej0): For training of atari agents wrappers are usually used.
-  # Below we have a hacky solution which is a workaround to be used together
-  # with atari_wrappers.MaxAndSkipEnv.
-  def get_action(self, observation=None):
-    if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
-    if self._skip_step == self._skip - 1: self._obs_buffer[1] = observation
-    self._skip_step = (self._skip_step + 1) % self._skip
-    if self._skip_step == 0:
-      max_frame = self._obs_buffer.max(axis=0)
-      self._last_action = int(tf.get_default_session().run(
-          self._last_policy_op,
-          feed_dict={self._max_frame_pl: max_frame})[0, 0])
-    return self._last_action
-
-  @property
-  def env_name(self):
-    return "PongDeterministic-v4"
-
-  @property
-  def num_actions(self):
-    return 4
-
-  @property
-  def num_rewards(self):
-    return 2
-
-  @property
-  def num_steps(self):
-    return 5000
+    super(GymSimulatedDiscreteProblem, self).__init__(*args, **kwargs)
+    # TODO: pull this outside.
+    self.in_graph_wrappers = [(TimeLimitWrapper, {"timelimit": 150}),
+                              (MaxAndSkipWrapper, {"skip": 4})]
+    self.simulated_environment = True
+    self.movies_fps = 2
+
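+  # In the simulated setting we additionally restore the learned environment
+  # model (variables matching ".*basic_conv_gen.*") from the latest checkpoint
+  # in FLAGS.output_dir.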
+  def restore_networks(self, sess):
+    super(GymSimulatedDiscreteProblem, self).restore_networks(sess)
+
+    # TODO: adjust the regexp for different models.
+    env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))
+    sess = tf.get_default_session()
+
+    ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
+    ckpt = ckpts.model_checkpoint_path
+    env_model_loader.restore(sess, ckpt)
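
For reference, a minimal usage sketch of the new problem class (not part of this commit). The module path, data directories and the snake_case problem name are assumptions based on t2t's usual registry conventions:

    from tensor2tensor.utils import registry
    # Importing the module that defines GymDiscreteProblem registers the problem.
    from tensor2tensor.data_generators import gym_problems  # assumed module path

    gym_problem = registry.problem("gym_discrete_problem")
    # TFRecords and the optional mp4 go to data_dir; PNG frames go to tmp_dir.
    gym_problem.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")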