import collections
from datetime import datetime
import functools
import math
import time
from typing import Any, Callable, Dict, Optional, Sequence

from brax import envs
from brax.envs import to_torch
from brax.io import metrics
from brax.training.agents.ppo import train as ppo
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

# Have torch allocate on the device first, to prevent JAX from swallowing up
# all the GPU memory. By default JAX will pre-allocate 90% of the available
# GPU memory:
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
v = torch.ones(1, device=DEVICE)  # touch the device so torch claims memory before JAX does
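# (JAX's pre-allocation can also be disabled by setting the environment
# variable XLA_PYTHON_CLIENT_PREALLOCATE=false; see the link above.)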

class PPOAgent(nn.Module):
  """
  Standard PPO Agent with GAE and observation normalization.
  """

  def __init__(self, policy_layers: Sequence[int], value_layers: Sequence[int],
               entropy_cost: float, discounting: float, reward_scaling: float,
               device: str):
    super().__init__()

    # Policy network: outputs 2 * action_size values per observation
    # (the loc and scale of a Normal that is squashed by tanh at action time).
    policy = []
    for w1, w2 in zip(policy_layers, policy_layers[1:]):
      policy.append(nn.Linear(w1, w2))
      policy.append(nn.SiLU())
    policy.pop()  # drop the final activation
    self.policy = nn.Sequential(*policy)

    # Value network: outputs a single baseline value per observation.
    value = []
    for w1, w2 in zip(value_layers, value_layers[1:]):
      value.append(nn.Linear(w1, w2))
      value.append(nn.SiLU())
    value.pop()  # drop the final activation
    self.value = nn.Sequential(*value)

    # Running statistics for observation normalization.
    self.num_steps = torch.zeros((), device=device)
    self.running_mean = torch.zeros(policy_layers[0], device=device)
    self.running_variance = torch.zeros(policy_layers[0], device=device)

    self.entropy_cost = entropy_cost
    self.discounting = discounting
    self.reward_scaling = reward_scaling
    self.lambda_ = 0.95  # GAE lambda
    self.epsilon = 0.3  # PPO clipping epsilon
    self.device = device

  @torch.jit.export
  def dist_create(self, logits):
    """
    Normal followed by tanh.

    torch.distribution doesn't work with torch.jit, so we roll our own.
    """
    loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)
    scale = F.softplus(scale) + .001
    return loc, scale

  @torch.jit.export
  def dist_sample_no_postprocess(self, loc, scale):
    return torch.normal(loc, scale)

  @classmethod
  def dist_postprocess(cls, x):
    return torch.tanh(x)

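  # The entropy and log-prob below are for the tanh-transformed Normal: the
  # Normal terms are corrected by the log-determinant of the tanh Jacobian,
  # log|d tanh(u)/du| = 2 * (log 2 - u - softplus(-2u)).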
  @torch.jit.export
  def dist_entropy(self, loc, scale):
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    entropy = 0.5 + log_normalized
    entropy = entropy * torch.ones_like(loc)
    dist = torch.normal(loc, scale)  # single sample to estimate the tanh log-det term
    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
    entropy = entropy + log_det_jacobian
    return entropy.sum(dim=-1)

  @torch.jit.export
  def dist_log_prob(self, loc, scale, dist):
    log_unnormalized = -0.5 * ((dist - loc) / scale).square()
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
    log_prob = log_unnormalized - log_normalized - log_det_jacobian
    return log_prob.sum(dim=-1)

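  # Batched Welford-style running statistics: num_steps counts the observations
  # seen so far, running_mean is the streaming mean, and running_variance
  # accumulates the sum of squared deviations (divided by num_steps when
  # normalizing).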
  @torch.jit.export
  def update_normalization(self, observation):
    self.num_steps += observation.shape[0] * observation.shape[1]
    input_to_old_mean = observation - self.running_mean
    mean_diff = torch.sum(input_to_old_mean / self.num_steps, dim=(0, 1))
    self.running_mean = self.running_mean + mean_diff
    input_to_new_mean = observation - self.running_mean
    var_diff = torch.sum(input_to_new_mean * input_to_old_mean, dim=(0, 1))
    self.running_variance = self.running_variance + var_diff

  @torch.jit.export
  def normalize(self, observation):
    variance = self.running_variance / (self.num_steps + 1.0)
    variance = torch.clip(variance, 1e-6, 1e6)
    return ((observation - self.running_mean) / variance.sqrt()).clip(-5, 5)

  @torch.jit.export
  def get_logits_action(self, observation):
    observation = self.normalize(observation)
    logits = self.policy(observation)
    loc, scale = self.dist_create(logits)
    action = self.dist_sample_no_postprocess(loc, scale)
    # Note: the returned action is pre-tanh; callers apply dist_postprocess
    # before stepping the environment.
    return logits, action

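  # Generalized Advantage Estimation: delta_t = r_t + gamma * (1 - done_t) *
  # V(s_{t+1}) - V(s_t) is accumulated backwards with factor gamma * lambda;
  # the truncation mask stops bootstrapping across truncated episodes.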
  @torch.jit.export
  def compute_gae(self, truncation, termination, reward, values,
                  bootstrap_value):
    truncation_mask = 1 - truncation
    # Append bootstrapped value to get [v1, ..., v_t+1]
    values_t_plus_1 = torch.cat(
        [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    deltas = reward + self.discounting * (
        1 - termination) * values_t_plus_1 - values
    deltas *= truncation_mask

    acc = torch.zeros_like(bootstrap_value)
    vs_minus_v_xs = torch.zeros_like(truncation_mask)

    for ti in range(truncation_mask.shape[0]):
      ti = truncation_mask.shape[0] - ti - 1
      acc = deltas[ti] + self.discounting * (
          1 - termination[ti]) * truncation_mask[ti] * self.lambda_ * acc
      vs_minus_v_xs[ti] = acc

    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values
    vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], 0)
    advantages = (reward + self.discounting *
                  (1 - termination) * vs_t_plus_1 - values) * truncation_mask
    return vs, advantages

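  # PPO loss = clipped surrogate policy loss + 0.25 * value regression loss
  # - entropy_cost * entropy bonus.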
  @torch.jit.export
  def loss(self, td: Dict[str, torch.Tensor]):
    observation = self.normalize(td['observation'])
    policy_logits = self.policy(observation[:-1])
    baseline = self.value(observation)
    baseline = torch.squeeze(baseline, dim=-1)

    # Use last baseline value (from the value function) to bootstrap.
    bootstrap_value = baseline[-1]
    baseline = baseline[:-1]
    reward = td['reward'] * self.reward_scaling
    termination = td['done'] * (1 - td['truncation'])  # done but not truncated

    loc, scale = self.dist_create(td['logits'])
    behaviour_action_log_probs = self.dist_log_prob(loc, scale, td['action'])
    loc, scale = self.dist_create(policy_logits)
    target_action_log_probs = self.dist_log_prob(loc, scale, td['action'])

    with torch.no_grad():
      vs, advantages = self.compute_gae(
          truncation=td['truncation'],
          termination=termination,
          reward=reward,
          values=baseline,
          bootstrap_value=bootstrap_value)

    # Clipped surrogate policy loss.
    rho_s = torch.exp(target_action_log_probs - behaviour_action_log_probs)
    surrogate_loss1 = rho_s * advantages
    surrogate_loss2 = rho_s.clip(1 - self.epsilon, 1 + self.epsilon) * advantages
    policy_loss = -torch.mean(torch.minimum(surrogate_loss1, surrogate_loss2))

    # Value function loss
    v_error = vs - baseline
    v_loss = torch.mean(v_error * v_error) * 0.5 * 0.5

    # Entropy reward
    entropy = torch.mean(self.dist_entropy(loc, scale))
    entropy_loss = self.entropy_cost * -entropy

    return policy_loss + v_loss + entropy_loss

StepData = collections.namedtuple(
    'StepData',
    ('observation', 'logits', 'action', 'reward', 'done', 'truncation'))

def sd_map(f: Callable[..., torch.Tensor], *sds) -> StepData:
  """
  Map a function over each field in StepData.
  """
  items = {}
  keys = sds[0]._asdict().keys()
  for k in keys:
    items[k] = f(*[sd._asdict()[k] for sd in sds])
  return StepData(**items)

def eval_unroll(agent, env, length):
  """
  Return the number of completed episodes and the average reward per episode
  for a single unroll.
  """
  observation = env.reset()
  episodes = torch.zeros((), device=agent.device)
  episode_reward = torch.zeros((), device=agent.device)
  for _ in range(length):
    _, action = agent.get_logits_action(observation)
    observation, reward, done, _ = env.step(PPOAgent.dist_postprocess(action))
    episodes += torch.sum(done)
    episode_reward += torch.sum(reward)
  return episodes, episode_reward / episodes

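# train_unroll returns the final observation plus a StepData whose fields are
# stacked to [num_unrolls, unroll_length, num_envs, ...] (observation has
# unroll_length + 1 steps, since it includes the initial observation).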
def train_unroll(agent, env, observation, num_unrolls, unroll_length):
  """
  Return step data over multiple unrolls.
  """
  sd = StepData([], [], [], [], [], [])
  for _ in range(num_unrolls):
    one_unroll = StepData([observation], [], [], [], [], [])
    for _ in range(unroll_length):
      logits, action = agent.get_logits_action(observation)
      observation, reward, done, info = env.step(
          PPOAgent.dist_postprocess(action))
      one_unroll.observation.append(observation)
      one_unroll.logits.append(logits)
      one_unroll.action.append(action)
      one_unroll.reward.append(reward)
      one_unroll.done.append(done)
      one_unroll.truncation.append(info['truncation'])
    one_unroll = sd_map(torch.stack, one_unroll)
    sd = sd_map(lambda x, y: x + [y], sd, one_unroll)
  td = sd_map(torch.stack, sd)
  return observation, td

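# Training loop: eval_frequency evaluation rounds; between evaluations,
# num_epochs iterations of (collect num_unrolls rollouts, then make
# num_update_epochs passes over num_minibatches shuffled minibatches).
# Each iteration consumes batch_size * num_minibatches * unroll_length
# environment steps.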
def train(
    env_name: str = 'ant',
    # env_name: str = 'FetchSlide-v2',
    num_envs: int = 2048,
    episode_length: int = 1000,
    device: str = DEVICE,
    num_timesteps: int = 30_000_000,
    eval_frequency: int = 10,
    unroll_length: int = 5,
    batch_size: int = 1024,
    num_minibatches: int = 32,
    num_update_epochs: int = 4,
    reward_scaling: float = .1,
    entropy_cost: float = 1e-2,
    discounting: float = .97,
    learning_rate: float = 3e-4,
    progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None,
):
  """
  Trains a policy via PPO.
  """
  gym_name = f'brax-{env_name}-v0'
  # gym_name = f'FetchSlide-v2'
  if gym_name not in gym.envs.registry.keys():
    entry_point = functools.partial(envs.create_gym_env, env_name=env_name)
    gym.register(gym_name, entry_point=entry_point)
  # NOTE: the rest of this file assumes a batched env with the classic Gym API
  # (reset() -> obs, step() -> (obs, reward, done, info)), which is what Brax's
  # gym wrappers expose; depending on the installed gym/gymnasium version this
  # may need a compatibility shim.
  env = gym.make(gym_name, batch_size=num_envs, episode_length=episode_length)
  # automatically convert between jax ndarrays and torch tensors:
  env = to_torch.JaxToTorchWrapper(env, device=device)

  # env warmup
  env.reset()
  action = torch.zeros(env.action_space.shape).to(device)
  env.step(action)

  # create the agent
  policy_layers = [
      env.observation_space.shape[-1], 64, 64,
      env.action_space.shape[-1] * 2  # loc and scale for each action dimension
  ]
  value_layers = [env.observation_space.shape[-1], 64, 64, 1]
  agent = PPOAgent(policy_layers, value_layers, entropy_cost, discounting,
                   reward_scaling, device)
  agent = torch.jit.script(agent.to(device))
  optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

  sps = 0
  total_steps = 0
  total_loss = 0
  for eval_i in range(eval_frequency + 1):
    if progress_fn:
      t = time.time()
      with torch.no_grad():
        episode_count, episode_reward = eval_unroll(agent, env, episode_length)
      duration = time.time() - t
      # TODO: only count stats from completed episodes
      episode_avg_length = env.num_envs * episode_length / episode_count
      eval_sps = env.num_envs * episode_length / duration
      progress = {
          'eval/episode_reward': episode_reward,
          'eval/completed_episodes': episode_count,
          'eval/avg_episode_length': episode_avg_length,
          'speed/sps': sps,
          'speed/eval_sps': eval_sps,
          'losses/total_loss': total_loss,
      }
      progress_fn(total_steps, progress)

    if eval_i == eval_frequency:
      break

    observation = env.reset()
    num_steps = batch_size * num_minibatches * unroll_length
    num_epochs = num_timesteps // (num_steps * eval_frequency)
    num_unrolls = batch_size * num_minibatches // env.num_envs
    total_loss = 0
    t = time.time()
    for _ in range(num_epochs):
      observation, td = train_unroll(agent, env, observation, num_unrolls,
                                     unroll_length)

      # make unroll first
      def unroll_first(data):
        data = data.swapaxes(0, 1)
        return data.reshape([data.shape[0], -1] + list(data.shape[3:]))
      td = sd_map(unroll_first, td)
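      # td fields now have time as the leading axis: observation is
      # [unroll_length + 1, num_unrolls * num_envs, ...]; the other fields are
      # [unroll_length, num_unrolls * num_envs, ...].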

      # update normalization statistics
      agent.update_normalization(td.observation)

      for _ in range(num_update_epochs):
        # shuffle and batch the data
        with torch.no_grad():
          permutation = torch.randperm(td.observation.shape[1], device=device)
          def shuffle_batch(data):
            data = data[:, permutation]
            data = data.reshape([data.shape[0], num_minibatches, -1] +
                                list(data.shape[2:]))
            return data.swapaxes(0, 1)
          epoch_td = sd_map(shuffle_batch, td)
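        # epoch_td fields have shape [num_minibatches, time, batch_size, ...].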

        for minibatch_i in range(num_minibatches):
          td_minibatch = sd_map(lambda d: d[minibatch_i], epoch_td)
          loss = agent.loss(td_minibatch._asdict())
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          total_loss += loss.detach()

    duration = time.time() - t
    total_steps += num_epochs * num_steps
    total_loss = total_loss / (num_epochs * num_update_epochs * num_minibatches)
    sps = num_epochs * num_steps / duration

  return agent, env

def main() -> None:
  xdata, ydata, eval_sps, train_sps, times = [], [], [], [], [datetime.now()]

  def progress(num_steps, metrics):
    times.append(datetime.now())
    xdata.append(num_steps)
    # copy to cpu, otherwise matplotlib throws an exception
    reward = metrics['eval/episode_reward'].cpu()
    ydata.append(reward)
    eval_sps.append(metrics['speed/eval_sps'])
    train_sps.append(metrics['speed/sps'])

    # plt.xlim([0, 30_000_000])
    # plt.ylim([0, 6_000])
    # plt.xlabel('# environment steps')
    # plt.ylabel('reward per episode')
    # plt.plot(xdata, ydata)
    # plt.show()

  agent, env = train(progress_fn=progress)

  print(f'time to jit: {times[1] - times[0]}')
  print(f'time to train: {times[-1] - times[1]}')
  print(f'eval steps/sec: {np.mean(eval_sps[1:])}')
  print(f'train steps/sec: {np.mean(train_sps[1:])}')
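
# Run training when executed as a script.
if __name__ == '__main__':
  main()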