
Commit 8c1805f

Merge pull request #10 from mohammadzainabbas/dev
Dev
2 parents 21af505 + 4968c77 commit 8c1805f

File tree

9 files changed: +1925 -96 lines


notebooks/demo.ipynb

Lines changed: 383 additions & 71 deletions
Large diffs are not rendered by default.

notebooks/demo_ppo_train.ipynb

Lines changed: 225 additions & 25 deletions
Large diffs are not rendered by default.

src/ppo_with_pytorch.py

Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@
import collections
from datetime import datetime
import functools
import math
import time
from typing import Any, Callable, Dict, Optional, Sequence

from brax import envs
from brax.envs import to_torch
from brax.io import metrics
from brax.training.agents.ppo import train as ppo
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

# have torch allocate on device first, to prevent JAX from swallowing up all the
# GPU memory. By default JAX will pre-allocate 90% of the available GPU memory:
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
v = torch.ones(1, device=DEVICE)

class PPOAgent(nn.Module):
  """
  Standard PPO Agent with GAE and observation normalization.
  """
  def __init__(self, policy_layers: Sequence[int], value_layers: Sequence[int], entropy_cost: float, discounting: float, reward_scaling: float, device: str):
    super(PPOAgent, self).__init__()
    policy = []

    for w1, w2 in zip(policy_layers, policy_layers[1:]):
      policy.append(nn.Linear(w1, w2))
      policy.append(nn.SiLU())
    policy.pop()  # drop the final activation
    self.policy = nn.Sequential(*policy)

    value = []
    for w1, w2 in zip(value_layers, value_layers[1:]):
      value.append(nn.Linear(w1, w2))
      value.append(nn.SiLU())
    value.pop()  # drop the final activation
    self.value = nn.Sequential(*value)

    self.num_steps = torch.zeros((), device=device)
    self.running_mean = torch.zeros(policy_layers[0], device=device)
    self.running_variance = torch.zeros(policy_layers[0], device=device)

    self.entropy_cost = entropy_cost
    self.discounting = discounting
    self.reward_scaling = reward_scaling
    self.lambda_ = 0.95
    self.epsilon = 0.3
    self.device = device
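
  # The final policy layer emits 2 * action_size values per observation;
  # dist_create below splits them into the mean and (softplus-transformed)
  # scale of a Gaussian whose samples are squashed by tanh in dist_postprocess.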

  @torch.jit.export
  def dist_create(self, logits):
    """
    Normal followed by tanh.

    torch.distribution doesn't work with torch.jit, so we roll our own.
    """
    loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)
    scale = F.softplus(scale) + .001
    return loc, scale

  @torch.jit.export
  def dist_sample_no_postprocess(self, loc, scale):
    return torch.normal(loc, scale)

  @classmethod
  def dist_postprocess(cls, x):
    return torch.tanh(x)

  @torch.jit.export
  def dist_entropy(self, loc, scale):
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    entropy = 0.5 + log_normalized
    entropy = entropy * torch.ones_like(loc)
    dist = torch.normal(loc, scale)
    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
    entropy = entropy + log_det_jacobian
    return entropy.sum(dim=-1)

  @torch.jit.export
  def dist_log_prob(self, loc, scale, dist):
    log_unnormalized = -0.5 * ((dist - loc) / scale).square()
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
    log_prob = log_unnormalized - log_normalized - log_det_jacobian
    return log_prob.sum(dim=-1)
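
  # dist_entropy and dist_log_prob correct the base Normal density for the tanh
  # squashing applied in dist_postprocess: for a pre-squash sample u,
  #   log|d tanh(u)/du| = log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)),
  # which is the log_det_jacobian term used above.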

  @torch.jit.export
  def update_normalization(self, observation):
    self.num_steps += observation.shape[0] * observation.shape[1]
    input_to_old_mean = observation - self.running_mean
    mean_diff = torch.sum(input_to_old_mean / self.num_steps, dim=(0, 1))
    self.running_mean = self.running_mean + mean_diff
    input_to_new_mean = observation - self.running_mean
    var_diff = torch.sum(input_to_new_mean * input_to_old_mean, dim=(0, 1))
    self.running_variance = self.running_variance + var_diff

  @torch.jit.export
  def normalize(self, observation):
    variance = self.running_variance / (self.num_steps + 1.0)
    variance = torch.clip(variance, 1e-6, 1e6)
    return ((observation - self.running_mean) / variance.sqrt()).clip(-5, 5)
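
  # update_normalization keeps streaming per-feature statistics over the
  # [unroll, batch] dimensions: running_mean is updated incrementally and
  # running_variance accumulates sum((x - old_mean) * (x - new_mean)), so
  # dividing by num_steps in normalize recovers the per-feature variance.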

  @torch.jit.export
  def get_logits_action(self, observation):
    observation = self.normalize(observation)
    logits = self.policy(observation)
    loc, scale = self.dist_create(logits)
    action = self.dist_sample_no_postprocess(loc, scale)
    return logits, action

  @torch.jit.export
  def compute_gae(self, truncation, termination, reward, values, bootstrap_value):
    truncation_mask = 1 - truncation
    # Append bootstrapped value to get [v1, ..., v_t+1]
    values_t_plus_1 = torch.cat([values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    deltas = reward + self.discounting * (1 - termination) * values_t_plus_1 - values
    deltas *= truncation_mask

    acc = torch.zeros_like(bootstrap_value)
    vs_minus_v_xs = torch.zeros_like(truncation_mask)

    for ti in range(truncation_mask.shape[0]):
      ti = truncation_mask.shape[0] - ti - 1
      acc = deltas[ti] + self.discounting * (1 - termination[ti]) * truncation_mask[ti] * self.lambda_ * acc
      vs_minus_v_xs[ti] = acc

    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values
    vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], 0)
    advantages = (reward + self.discounting * (1 - termination) * vs_t_plus_1 - values) * truncation_mask
    return vs, advantages
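
  # compute_gae runs the GAE(lambda) recursion backwards in time:
  #   delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
  #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
  # with truncated steps masked out. It returns vs = A + V as value targets,
  # plus advantages recomputed as one-step TD errors against those targets.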

  @torch.jit.export
  def loss(self, td: Dict[str, torch.Tensor]):
    observation = self.normalize(td['observation'])
    policy_logits = self.policy(observation[:-1])
    baseline = self.value(observation)
    baseline = torch.squeeze(baseline, dim=-1)

    # Use last baseline value (from the value function) to bootstrap.
    bootstrap_value = baseline[-1]
    baseline = baseline[:-1]
    reward = td['reward'] * self.reward_scaling
    termination = td['done'] * (1 - td['truncation'])

    loc, scale = self.dist_create(td['logits'])
    behaviour_action_log_probs = self.dist_log_prob(loc, scale, td['action'])
    loc, scale = self.dist_create(policy_logits)
    target_action_log_probs = self.dist_log_prob(loc, scale, td['action'])

    with torch.no_grad():
      vs, advantages = self.compute_gae(
          truncation=td['truncation'],
          termination=termination,
          reward=reward,
          values=baseline,
          bootstrap_value=bootstrap_value)

    rho_s = torch.exp(target_action_log_probs - behaviour_action_log_probs)
    surrogate_loss1 = rho_s * advantages
    surrogate_loss2 = rho_s.clip(1 - self.epsilon, 1 + self.epsilon) * advantages
    policy_loss = -torch.mean(torch.minimum(surrogate_loss1, surrogate_loss2))

    # Value function loss
    v_error = vs - baseline
    v_loss = torch.mean(v_error * v_error) * 0.5 * 0.5

    # Entropy reward
    entropy = torch.mean(self.dist_entropy(loc, scale))
    entropy_loss = self.entropy_cost * -entropy

    return policy_loss + v_loss + entropy_loss
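
  # The total loss is the PPO clipped surrogate,
  #   -E[min(rho * A, clip(rho, 1 - epsilon, 1 + epsilon) * A)],
  # plus a scaled squared value error (0.25 * mean((vs - baseline)^2)) and an
  # entropy bonus weighted by entropy_cost, where rho is the ratio of current
  # to behaviour policy probabilities of the logged actions.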

StepData = collections.namedtuple('StepData', ('observation', 'logits', 'action', 'reward', 'done', 'truncation'))

def sd_map(f: Callable[..., torch.Tensor], *sds) -> StepData:
  """
  Map a function over each field in StepData.
  """
  items = {}
  keys = sds[0]._asdict().keys()
  for k in keys:
    items[k] = f(*[sd._asdict()[k] for sd in sds])
  return StepData(**items)
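
# For example, sd_map(torch.stack, sd) stacks each field's list of tensors into
# one tensor, and sd_map(lambda x, y: x + [y], sd, one_unroll) appends the
# matching field of one_unroll to each list in sd (both used in train_unroll).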

def eval_unroll(agent, env, length):
  """
  Return number of episodes and average reward for a single unroll.
  """
  observation = env.reset()
  episodes = torch.zeros((), device=agent.device)
  episode_reward = torch.zeros((), device=agent.device)
  for _ in range(length):
    _, action = agent.get_logits_action(observation)
    observation, reward, done, _ = env.step(PPOAgent.dist_postprocess(action))
    episodes += torch.sum(done)
    episode_reward += torch.sum(reward)
  return episodes, episode_reward / episodes
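
# The reward average divides by the number of completed episodes, so reward
# collected in episodes still in progress at the end of the unroll is included
# in the numerator as well.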

def train_unroll(agent, env, observation, num_unrolls, unroll_length):
  """
  Return step data over multiple unrolls.
  """
  sd = StepData([], [], [], [], [], [])
  for _ in range(num_unrolls):
    one_unroll = StepData([observation], [], [], [], [], [])
    for _ in range(unroll_length):
      logits, action = agent.get_logits_action(observation)
      observation, reward, done, info = env.step(PPOAgent.dist_postprocess(action))
      one_unroll.observation.append(observation)
      one_unroll.logits.append(logits)
      one_unroll.action.append(action)
      one_unroll.reward.append(reward)
      one_unroll.done.append(done)
      one_unroll.truncation.append(info['truncation'])
    one_unroll = sd_map(torch.stack, one_unroll)
    sd = sd_map(lambda x, y: x + [y], sd, one_unroll)
  td = sd_map(torch.stack, sd)
  return observation, td
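
# Each td field is stacked as [num_unrolls, unroll_length, num_envs, ...]
# (observations carry one extra step, unroll_length + 1, for bootstrapping);
# train() below reshapes these so the time axis comes first and the unroll and
# env axes are flattened together.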

def train(
    env_name: str = 'ant',
    # env_name: str = 'FetchSlide-v2',
    num_envs: int = 2048,
    episode_length: int = 1000,
    device: str = DEVICE,
    num_timesteps: int = 30_000_000,
    eval_frequency: int = 10,
    unroll_length: int = 5,
    batch_size: int = 1024,
    num_minibatches: int = 32,
    num_update_epochs: int = 4,
    reward_scaling: float = .1,
    entropy_cost: float = 1e-2,
    discounting: float = .97,
    learning_rate: float = 3e-4,
    progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None,
):
  """
  Trains a policy via PPO.
  """
  gym_name = f'brax-{env_name}-v0'
  # gym_name = f'FetchSlide-v2'
  if gym_name not in gym.envs.registry.keys():
    entry_point = functools.partial(envs.create_gym_env, env_name=env_name)
    gym.register(gym_name, entry_point=entry_point)
  # env = gym.make(gym_name, batch_size=num_envs, episode_length=episode_length)
  env = gym.make(gym_name)
  # automatically convert between jax ndarrays and torch tensors:
  env = to_torch.JaxToTorchWrapper(env, device=device)

  # env warmup
  env.reset()
  action = torch.zeros(env.action_space.shape).to(device)
  env.step(action)

  # create the agent
  policy_layers = [env.observation_space.shape[-1], 64, 64, env.action_space.shape[-1] * 2]
  value_layers = [env.observation_space.shape[-1], 64, 64, 1]
  agent = PPOAgent(policy_layers, value_layers, entropy_cost, discounting, reward_scaling, device)
  agent = torch.jit.script(agent.to(device))
  optimizer = optim.Adam(agent.parameters(), lr=learning_rate)
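
  # torch.jit.script compiles the agent to TorchScript; the @torch.jit.export
  # decorators above keep methods other than forward (loss, compute_gae, etc.)
  # callable on the scripted module.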

  sps = 0
  total_steps = 0
  total_loss = 0
  for eval_i in range(eval_frequency + 1):
    if progress_fn:
      t = time.time()
      with torch.no_grad():
        episode_count, episode_reward = eval_unroll(agent, env, episode_length)
      duration = time.time() - t
      # TODO: only count stats from completed episodes
      episode_avg_length = env.num_envs * episode_length / episode_count
      eval_sps = env.num_envs * episode_length / duration
      progress = {
          'eval/episode_reward': episode_reward,
          'eval/completed_episodes': episode_count,
          'eval/avg_episode_length': episode_avg_length,
          'speed/sps': sps,
          'speed/eval_sps': eval_sps,
          'losses/total_loss': total_loss,
      }
      progress_fn(total_steps, progress)

    if eval_i == eval_frequency:
      break

    observation = env.reset()
    num_steps = batch_size * num_minibatches * unroll_length
    num_epochs = num_timesteps // (num_steps * eval_frequency)
    num_unrolls = batch_size * num_minibatches // env.num_envs
    total_loss = 0
    t = time.time()
    for _ in range(num_epochs):
      observation, td = train_unroll(agent, env, observation, num_unrolls, unroll_length)

      # make unroll first
      def unroll_first(data):
        data = data.swapaxes(0, 1)
        return data.reshape([data.shape[0], -1] + list(data.shape[3:]))
      td = sd_map(unroll_first, td)

      # update normalization statistics
      agent.update_normalization(td.observation)

      for _ in range(num_update_epochs):
        # shuffle and batch the data
        with torch.no_grad():
          permutation = torch.randperm(td.observation.shape[1], device=device)
          def shuffle_batch(data):
            data = data[:, permutation]
            data = data.reshape([data.shape[0], num_minibatches, -1] + list(data.shape[2:]))
            return data.swapaxes(0, 1)
          epoch_td = sd_map(shuffle_batch, td)

        for minibatch_i in range(num_minibatches):
          td_minibatch = sd_map(lambda d: d[minibatch_i], epoch_td)
          loss = agent.loss(td_minibatch._asdict())
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          total_loss += loss.detach()

    duration = time.time() - t
    total_steps += num_epochs * num_steps
    total_loss = total_loss / (num_epochs * num_update_epochs * num_minibatches)
    sps = num_epochs * num_steps / duration

  return agent, env
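
# With the defaults above, each epoch consumes num_steps = 1024 * 32 * 5 = 163,840
# environment steps, and num_epochs = 30_000_000 // (163_840 * 10) = 18 epochs run
# between evaluations, so roughly 29.5M of the requested 30M steps are used.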

def main() -> None:
  xdata, ydata, eval_sps, train_sps, times = [], [], [], [], [datetime.now()]

  def progress(num_steps, metrics):
    times.append(datetime.now())
    xdata.append(num_steps)
    # copy to cpu, otherwise matplotlib throws an exception
    reward = metrics['eval/episode_reward'].cpu()
    ydata.append(reward)
    eval_sps.append(metrics['speed/eval_sps'])
    train_sps.append(metrics['speed/sps'])

    # plt.xlim([0, 30_000_000])
    # plt.ylim([0, 6_000])
    # plt.xlabel('# environment steps')
    # plt.ylabel('reward per episode')
    # plt.plot(xdata, ydata)
    # plt.show()

  agent, env = train(progress_fn=progress)

  print(f'time to jit: {times[1] - times[0]}')
  print(f'time to train: {times[-1] - times[1]}')
  print(f'eval steps/sec: {np.mean(eval_sps[1:])}')
  print(f'train steps/sec: {np.mean(train_sps[1:])}')
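
# Note: main() is defined but never called in this module; to run the trainer as
# a script one would still need a standard entry-point guard (not part of this
# diff), e.g.:
#
#   if __name__ == '__main__':
#       main()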
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments
