2 changes: 1 addition & 1 deletion pyproject.toml
@@ -56,5 +56,5 @@ reportMissingImports = "none"
# This is required to ignore type checks of modules with stubs missing.
reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers
reportGeneralTypeIssues = "none" # -> usage of literal MISSING in dataclasses
reportOptionalMemberAccess = "warning"
reportOptionalMemberAccess = "none"
reportPrivateUsage = "warning"
28 changes: 5 additions & 23 deletions rsl_rl/algorithms/distillation.py
@@ -21,6 +21,7 @@ class Distillation:
def __init__(
self,
policy: StudentTeacher | StudentTeacherRecurrent,
storage: RolloutStorage,
num_learning_epochs: int = 1,
gradient_length: int = 15,
learning_rate: float = 1e-3,
@@ -46,12 +47,12 @@ def __init__(
# Distillation components
self.policy = policy
self.policy.to(self.device)
self.storage = None # Initialized later

# Initialize the optimizer
# Create the optimizer
self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate)

# Initialize the transition
# Add storage
self.storage = storage
self.transition = RolloutStorage.Transition()
self.last_hidden_states = (None, None)

@@ -73,24 +74,6 @@ def __init__(

self.num_updates = 0

def init_storage(
self,
training_type: str,
num_envs: int,
num_transitions_per_env: int,
obs: TensorDict,
actions_shape: tuple[int],
) -> None:
# Create rollout storage
self.storage = RolloutStorage(
training_type,
num_envs,
num_transitions_per_env,
obs,
actions_shape,
self.device,
)

def act(self, obs: TensorDict) -> torch.Tensor:
# Compute the actions
self.transition.actions = self.policy.act(obs).detach()
@@ -104,12 +87,11 @@ def process_env_step(
) -> None:
# Update the normalizers
self.policy.update_normalization(obs)

# Record the rewards and dones
self.transition.rewards = rewards
self.transition.dones = dones
# Record the transition
self.storage.add_transitions(self.transition)
self.storage.add_transition(self.transition)
self.transition.clear()
self.policy.reset(dones)

51 changes: 25 additions & 26 deletions rsl_rl/algorithms/ppo.py
@@ -26,6 +26,7 @@ class PPO:
def __init__(
self,
policy: ActorCritic | ActorCriticRecurrent,
storage: RolloutStorage,
num_learning_epochs: int = 5,
num_mini_batches: int = 4,
clip_param: float = 0.2,
@@ -38,8 +39,8 @@ def __init__(
use_clipped_value_loss: bool = True,
schedule: str = "adaptive",
desired_kl: float = 0.01,
device: str = "cpu",
normalize_advantage_per_mini_batch: bool = False,
device: str = "cpu",
# RND parameters
rnd_cfg: dict | None = None,
# Symmetry parameters
@@ -100,11 +101,11 @@ def __init__(
self.policy = policy
self.policy.to(self.device)

# Create optimizer
# Create the optimizer
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

# Create rollout storage
self.storage: RolloutStorage | None = None
# Add storage
self.storage = storage
self.transition = RolloutStorage.Transition()

# PPO parameters
@@ -122,24 +123,6 @@ def __init__(
self.learning_rate = learning_rate
self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

def init_storage(
self,
training_type: str,
num_envs: int,
num_transitions_per_env: int,
obs: TensorDict,
actions_shape: tuple[int] | list[int],
) -> None:
# Create rollout storage
self.storage = RolloutStorage(
training_type,
num_envs,
num_transitions_per_env,
obs,
actions_shape,
self.device,
)

def act(self, obs: TensorDict) -> torch.Tensor:
if self.policy.is_recurrent:
self.transition.hidden_states = self.policy.get_hidden_states()
@@ -180,16 +163,32 @@ def process_env_step(
)

# Record the transition
self.storage.add_transitions(self.transition)
self.storage.add_transition(self.transition)
self.transition.clear()
self.policy.reset(dones)

def compute_returns(self, obs: TensorDict) -> None:
st = self.storage
# Compute value for the last step
last_values = self.policy.evaluate(obs).detach()
self.storage.compute_returns(
last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
)
# Compute returns and advantages
advantage = 0
for step in reversed(range(st.num_transitions_per_env)):
# If we are at the last step, bootstrap the return value
next_values = last_values if step == st.num_transitions_per_env - 1 else st.values[step + 1]
# 1 if we are not in a terminal state, 0 otherwise
next_is_not_terminal = 1.0 - st.dones[step].float()
# TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
delta = st.rewards[step] + next_is_not_terminal * self.gamma * next_values - st.values[step]
# Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
advantage = delta + next_is_not_terminal * self.gamma * self.lam * advantage
# Return: R_t = A(s_t, a_t) + V(s_t)
st.returns[step] = advantage + st.values[step]
# Compute the advantages
st.advantages = st.returns - st.values
# Normalize the advantages unless per-mini-batch normalization is used, to avoid normalizing twice
if not self.normalize_advantage_per_mini_batch:
st.advantages = (st.advantages - st.advantages.mean()) / (st.advantages.std() + 1e-8)

def update(self) -> dict[str, float]:
mean_value_loss = 0
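The compute_returns loop above is the standard generalized advantage estimation (GAE) recursion, now inlined in PPO rather than delegated to the storage. A minimal standalone sketch of the same computation, assuming rollout tensors of shape (num_steps, num_envs, 1) as laid out in RolloutStorage:

import torch

def compute_gae(rewards, values, dones, last_values, gamma=0.99, lam=0.95):
    """Return (returns, advantages) for a rollout of shape (num_steps, num_envs, 1)."""
    num_steps = rewards.shape[0]
    returns = torch.zeros_like(values)
    advantage = 0.0
    for step in reversed(range(num_steps)):
        # Bootstrap from the value of the state after the rollout at the last step
        next_values = last_values if step == num_steps - 1 else values[step + 1]
        next_is_not_terminal = 1.0 - dones[step].float()
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[step] + next_is_not_terminal * gamma * next_values - values[step]
        # GAE recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        advantage = delta + next_is_not_terminal * gamma * lam * advantage
        returns[step] = advantage + values[step]
    advantages = returns - values
    return returns, advantages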
17 changes: 7 additions & 10 deletions rsl_rl/runners/distillation_runner.py
@@ -16,6 +16,7 @@
from rsl_rl.env import VecEnv
from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
from rsl_rl.runners import OnPolicyRunner
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -158,19 +159,15 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)

# Initialize the storage
storage = RolloutStorage(
"distillation", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
)

# Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg: Distillation = alg_class(
student_teacher, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)

# Initialize the storage
alg.init_storage(
"distillation",
self.env.num_envs,
self.num_steps_per_env,
obs,
[self.env.num_actions],
student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)

return alg
17 changes: 8 additions & 9 deletions rsl_rl/runners/on_policy_runner.py
@@ -17,6 +17,7 @@
from rsl_rl.algorithms import PPO
from rsl_rl.env import VecEnv
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -418,17 +419,15 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)

# Initialize the storage
storage = RolloutStorage(
"rl", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
)

# Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg)

# Initialize the storage
alg.init_storage(
"rl",
self.env.num_envs,
self.num_steps_per_env,
obs,
[self.env.num_actions],
alg: PPO = alg_class(
actor_critic, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)

return alg
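For downstream code that previously constructed an algorithm and then called init_storage(), a minimal migration sketch (env, obs, num_steps_per_env, actor_critic, alg_cfg, and device are placeholders; the pattern mirrors the two runner changes above):

from rsl_rl.algorithms import PPO
from rsl_rl.storage import RolloutStorage

# New: build the rollout storage yourself and pass it to the algorithm constructor.
storage = RolloutStorage("rl", env.num_envs, num_steps_per_env, obs, [env.num_actions], device)
alg = PPO(actor_critic, storage, device=device, **alg_cfg)

# Old (removed in this PR):
#   alg = PPO(actor_critic, device=device, **alg_cfg)
#   alg.init_storage("rl", env.num_envs, num_steps_per_env, obs, [env.num_actions])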
75 changes: 30 additions & 45 deletions rsl_rl/storage/rollout_storage.py
@@ -14,7 +14,15 @@


class RolloutStorage:
"""Storage for the data collected during a rollout.

The storage is populated with transitions during the rollout phase. For the learning phase, it provides a generator
whose type depends on the algorithm and the policy architecture.
"""

class Transition:
"""Storage for a single state transition."""

def __init__(self) -> None:
self.observations: TensorDict | None = None
self.actions: torch.Tensor | None = None
@@ -75,7 +83,7 @@ def __init__(
# Counter for the number of transitions stored
self.step = 0

def add_transitions(self, transition: Transition) -> None:
def add_transition(self, transition: Transition) -> None:
# Check if the transition is valid
if self.step >= self.num_transitions_per_env:
raise OverflowError("Rollout buffer overflow! You should call clear() before adding new transitions.")
@@ -103,53 +111,9 @@ def add_transitions(self, transition: Transition) -> None:
# Increment the counter
self.step += 1

def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
if hidden_states == (None, None):
return
# Make a tuple out of GRU hidden states to match the LSTM format
hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
# Initialize hidden states if needed
if self.saved_hidden_state_a is None:
self.saved_hidden_state_a = [
torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
for i in range(len(hidden_state_a))
]
self.saved_hidden_state_c = [
torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
for i in range(len(hidden_state_c))
]
# Copy the states
for i in range(len(hidden_state_a)):
self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])

def clear(self) -> None:
self.step = 0

def compute_returns(
self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True
) -> None:
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
# If we are at the last step, bootstrap the return value
next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
# 1 if we are not in a terminal state, 0 otherwise
next_is_not_terminal = 1.0 - self.dones[step].float()
# TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
# Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
advantage = delta + next_is_not_terminal * gamma * lam * advantage
# Return: R_t = A(s_t, a_t) + V(s_t)
self.returns[step] = advantage + self.values[step]

# Compute the advantages
self.advantages = self.returns - self.values
# Normalize the advantages if flag is set
# Note: This is to prevent double normalization (i.e. if per minibatch normalization is used)
if normalize_advantage:
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

# For distillation
def generator(self) -> Generator:
if self.training_type != "distillation":
@@ -289,3 +253,24 @@ def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int
)

first_traj = last_traj

def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
if hidden_states == (None, None):
return
# Make a tuple out of GRU hidden states to match the LSTM format
hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
# Initialize hidden states if needed
if self.saved_hidden_state_a is None:
self.saved_hidden_state_a = [
torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
for i in range(len(hidden_state_a))
]
self.saved_hidden_state_c = [
torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
for i in range(len(hidden_state_c))
]
# Copy the states
for i in range(len(hidden_state_a)):
self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])
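A short usage note on the generators mentioned in the new class docstring: distillation consumes the whole rollout through generator(), while PPO draws shuffled mini-batches and switches to the recurrent variant for recurrent policies. A sketch of that selection, assuming the update() code paths of the algorithms above (not part of this diff; mini-batch counts are placeholders):

# Inside an algorithm's update(), after the rollout has been collected:
if policy.is_recurrent:
    generator = storage.recurrent_mini_batch_generator(num_mini_batches, num_learning_epochs)
else:
    generator = storage.mini_batch_generator(num_mini_batches, num_learning_epochs)
for batch in generator:
    ...  # compute the PPO (or distillation) loss on each mini-batch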