Skip to content

Commit dc0265c

Browse files
Merge pull request #8 from mohammadzainabbas/dev
Dev
2 parents 40ccb09 + 44090bc commit dc0265c

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed

notebooks/demo_ppo_train.ipynb

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
{
2+
"cells": [
3+
{
4+
"attachments": {},
5+
"cell_type": "markdown",
6+
"metadata": {
7+
"id": "ssCOanHc8JH_"
8+
},
9+
"source": [
10+
"## Demo for step-by-step training with PPO"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {
17+
"id": "_sOmCoOrF0F8"
18+
},
19+
"outputs": [],
20+
"source": [
21+
"from datetime import datetime\n",
22+
"import functools\n",
23+
"import os\n",
24+
"from os import getcwd\n",
25+
"from os.path import join\n",
26+
"from IPython.display import HTML, clear_output\n",
27+
"\n",
28+
"import jax\n",
29+
"import jax.numpy as jnp\n",
30+
"import matplotlib.pyplot as plt\n",
31+
"\n",
32+
"try:\n",
33+
" import brax\n",
34+
"except ImportError:\n",
35+
" !pip install git+https://github.com/google/brax.git@main\n",
36+
" clear_output()\n",
37+
" import brax\n",
38+
"\n",
39+
"from brax import envs\n",
40+
"from brax import jumpy as jp\n",
41+
"from brax.io import html\n",
42+
"from brax.io import model\n",
43+
"from brax.training.agents.ppo import train as ppo\n",
44+
"\n",
45+
"from IPython.core.interactiveshell import InteractiveShell\n",
46+
"InteractiveShell.ast_node_interactivity = \"all\"\n",
47+
"\n",
48+
"if 'COLAB_TPU_ADDR' in os.environ:\n",
49+
" from jax.tools import colab_tpu\n",
50+
" colab_tpu.setup_tpu()"
51+
]
52+
},
53+
{
54+
"attachments": {},
55+
"cell_type": "markdown",
56+
"metadata": {
57+
"id": "Tm8zbPBcJ5RJ"
58+
},
59+
"source": [
60+
"#### Environment"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": null,
66+
"metadata": {
67+
"colab": {
68+
"base_uri": "https://localhost:8080/",
69+
"height": 480
70+
},
71+
"id": "NaJDZqhCLovU",
72+
"outputId": "50994b20-d788-4264-af00-a3f06d58f943"
73+
},
74+
"outputs": [],
75+
"source": [
76+
"SEED = 0\n",
77+
"env_name = \"grasp\"\n",
78+
"env = envs.get_environment(env_name=env_name)\n",
79+
"state = env.reset(rng=jp.random_prngkey(seed=SEED))\n",
80+
"\n",
81+
"HTML(html.render(env.sys, [state.qp]))"
82+
]
83+
},
84+
{
85+
"attachments": {},
86+
"cell_type": "markdown",
87+
"metadata": {},
88+
"source": [
89+
"#### Helper functions"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": null,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"def train_ppo(num_timesteps, env_name):\n",
99+
"\tprint(f\"Training PPO for '{num_timesteps}' timesteps'\")\n",
100+
"\n",
101+
"\tenv = envs.get_environment(env_name=env_name)\n",
102+
"\tstate = env.reset(rng=jp.random_prngkey(seed=SEED))\n",
103+
"\n",
104+
"\ttrain_fn = functools.partial(ppo.train, num_timesteps=num_timesteps, num_evals=10, reward_scaling=10, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=20, num_minibatches=32, num_updates_per_batch=2, discounting=0.99, learning_rate=3e-4, entropy_cost=0.001, num_envs=2048, batch_size=256)\n",
105+
"\n",
106+
"\tmax_y = 100\n",
107+
"\tmin_y = 0\n",
108+
"\n",
109+
"\txdata, ydata = [], []\n",
110+
"\ttimes = [datetime.now()]\n",
111+
"\n",
112+
"\tdef progress(num_steps, metrics):\n",
113+
"\t\ttimes.append(datetime.now())\n",
114+
"\t\txdata.append(num_steps)\n",
115+
"\t\tydata.append(metrics['eval/episode_reward'])\n",
116+
"\t\tclear_output(wait=True)\n",
117+
"\t\t# plt.xlim([0, train_fn.keywords['num_timesteps']])\n",
118+
"\t\t# plt.ylim([min_y, max_y])\n",
119+
"\t\t# plt.xlabel('# environment steps')\n",
120+
"\t\t# plt.ylabel('reward per episode')\n",
121+
"\t\t# plt.plot(xdata, ydata)\n",
122+
"\t\t# plt.show()\n",
123+
"\n",
124+
"\tmake_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)\n",
125+
"\tprint(f'time to jit: {times[1] - times[0]}')\n",
126+
"\tprint(f'time to train: {times[-1] - times[1]}')\n",
127+
"\n",
128+
"\treturn make_inference_fn, params, times, xdata, ydata\n",
129+
"\n",
130+
"def visual_rollout(inference_fn, env_name, steps=100, seed=0):\n",
131+
"\tenv = envs.create(env_name=env_name)\n",
132+
"\tjit_env_reset = jax.jit(env.reset)\n",
133+
"\tjit_env_step = jax.jit(env.step)\n",
134+
"\tjit_inference_fn = jax.jit(inference_fn)\n",
135+
"\n",
136+
"\trollout = []\n",
137+
"\trng = jax.random.PRNGKey(seed=seed)\n",
138+
"\tstate = jit_env_reset(rng=rng)\n",
139+
"\tfor _ in range(steps):\n",
140+
"\t\trollout.append(state)\n",
141+
"\t\tact_rng, rng = jax.random.split(rng)\n",
142+
"\t\tact, _ = jit_inference_fn(state.obs, act_rng)\n",
143+
"\t\tstate = jit_env_step(state, act)\n",
144+
"\n",
145+
"\treturn env.sys, [s.qp for s in rollout]"
146+
]
147+
},
148+
{
149+
"attachments": {},
150+
"cell_type": "markdown",
151+
"metadata": {},
152+
"source": [
153+
"#### Training (step-by-step)"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
161+
"source": [
162+
"training_num_timesteps = [1_000, 1_000_000, 100_000_000]\n",
163+
"vis_steps = [100, 150, 300]\n",
164+
"\n",
165+
"env_sys = []\n",
166+
"rollouts = []\n",
167+
"\n",
168+
"for idx, num_timesteps in enumerate(training_num_timesteps):\n",
169+
"\tmake_inference_fn, params, times, xdata, ydata = train_ppo(num_timesteps, env_name)\n",
170+
"\tinference_fn = make_inference_fn(params)\n",
171+
"\tsys, rollout = visual_rollout(inference_fn, env_name, steps=vis_steps[idx], seed=SEED)\n",
172+
"\tenv_sys.append(sys)\n",
173+
"\trollouts.append(rollout)"
174+
]
175+
},
176+
{
177+
"attachments": {},
178+
"cell_type": "markdown",
179+
"metadata": {},
180+
"source": [
181+
"#### Visualise learning"
182+
]
183+
},
184+
{
185+
"cell_type": "code",
186+
"execution_count": null,
187+
"metadata": {},
188+
"outputs": [],
189+
"source": [
190+
"for i, sys in enumerate(env_sys):\n",
191+
"\tHTML(html.render(sys, rollouts[i]))"
192+
]
193+
}
194+
],
195+
"metadata": {
196+
"accelerator": "TPU",
197+
"colab": {
198+
"name": "Brax Training.ipynb",
199+
"provenance": []
200+
},
201+
"gpuClass": "standard",
202+
"kernelspec": {
203+
"display_name": "reinforcement_learning",
204+
"language": "python",
205+
"name": "python3"
206+
},
207+
"language_info": {
208+
"name": "python",
209+
"version": "3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]"
210+
},
211+
"vscode": {
212+
"interpreter": {
213+
"hash": "b329387e251b95764b8f65684563519503b45dc8027da482b0a7bdbaa4a30d3e"
214+
}
215+
}
216+
},
217+
"nbformat": 4,
218+
"nbformat_minor": 0
219+
}

0 commit comments

Comments
 (0)