Skip to content

Commit 21af505

Browse files
Merge pull request #9 from mohammadzainabbas/dev
Dev
2 parents dc0265c + 5645230 commit 21af505

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

notebooks/demo_ppo_train.ipynb

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,13 @@
159159
"metadata": {},
160160
"outputs": [],
161161
"source": [
162-
"training_num_timesteps = [1_000, 1_000_000, 100_000_000]\n",
163-
"vis_steps = [100, 150, 300]\n",
162+
"training_num_timesteps = [1_000, 5_000_000, 400_000_000]\n",
164163
"\n",
165-
"env_sys = []\n",
166-
"rollouts = []\n",
164+
"inference_fns = []\n",
167165
"\n",
168166
"for idx, num_timesteps in enumerate(training_num_timesteps):\n",
169167
"\tmake_inference_fn, params, times, xdata, ydata = train_ppo(num_timesteps, env_name)\n",
170-
"\tinference_fn = make_inference_fn(params)\n",
171-
"\tsys, rollout = visual_rollout(inference_fn, env_name, steps=vis_steps[idx], seed=SEED)\n",
172-
"\tenv_sys.append(sys)\n",
173-
"\trollouts.append(rollout)"
168+
"\tinference_fns.append(make_inference_fn(params))"
174169
]
175170
},
176171
{
@@ -181,6 +176,23 @@
181176
"#### Visualise learning"
182177
]
183178
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": null,
182+
"metadata": {},
183+
"outputs": [],
184+
"source": [
185+
"vis_steps = [300, 500, 750]\n",
186+
"\n",
187+
"env_sys = []\n",
188+
"rollouts = []\n",
189+
"\n",
190+
"for idx, inference_fn in enumerate(inference_fns):\n",
191+
"\tsys, rollout = visual_rollout(inference_fn, env_name, steps=vis_steps[idx], seed=SEED)\n",
192+
"\tenv_sys.append(sys)\n",
193+
"\trollouts.append(rollout)"
194+
]
195+
},
184196
{
185197
"cell_type": "code",
186198
"execution_count": null,

0 commit comments

Comments
 (0)