|
159 | 159 | "metadata": {}, |
160 | 160 | "outputs": [], |
161 | 161 | "source": [ |
162 | | - "training_num_timesteps = [1_000, 1_000_000, 100_000_000]\n", |
163 | | - "vis_steps = [100, 150, 300]\n", |
| 162 | + "training_num_timesteps = [1_000, 5_000_000, 400_000_000]\n", |
164 | 163 | "\n", |
165 | | - "env_sys = []\n", |
166 | | - "rollouts = []\n", |
| 164 | + "inference_fns = []\n", |
167 | 165 | "\n", |
168 | 166 | "for idx, num_timesteps in enumerate(training_num_timesteps):\n", |
169 | 167 | "\tmake_inference_fn, params, times, xdata, ydata = train_ppo(num_timesteps, env_name)\n", |
170 | | - "\tinference_fn = make_inference_fn(params)\n", |
171 | | - "\tsys, rollout = visual_rollout(inference_fn, env_name, steps=vis_steps[idx], seed=SEED)\n", |
172 | | - "\tenv_sys.append(sys)\n", |
173 | | - "\trollouts.append(rollout)" |
| 168 | + "\tinference_fns.append(make_inference_fn(params))" |
174 | 169 | ] |
175 | 170 | }, |
176 | 171 | { |
|
181 | 176 | "#### Visualise learning" |
182 | 177 | ] |
183 | 178 | }, |
| 179 | + { |
| 180 | + "cell_type": "code", |
| 181 | + "execution_count": null, |
| 182 | + "metadata": {}, |
| 183 | + "outputs": [], |
| 184 | + "source": [ |
| 185 | + "vis_steps = [300, 500, 750]\n", |
| 186 | + "\n", |
| 187 | + "env_sys = []\n", |
| 188 | + "rollouts = []\n", |
| 189 | + "\n", |
| 190 | + "for idx, inference_fn in enumerate(inference_fns):\n", |
| 191 | + "\tsys, rollout = visual_rollout(inference_fn, env_name, steps=vis_steps[idx], seed=SEED)\n", |
| 192 | + "\tenv_sys.append(sys)\n", |
| 193 | + "\trollouts.append(rollout)" |
| 194 | + ] |
| 195 | + }, |
184 | 196 | { |
185 | 197 | "cell_type": "code", |
186 | 198 | "execution_count": null, |
|
0 commit comments