-
Notifications
You must be signed in to change notification settings - Fork 424
Adds perceptive actor-critic class #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pascal-roth
wants to merge
10
commits into
main
Choose a base branch
from
feature/perceptive-nav-rl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+494
−31
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
677a7b2
add files for perceptive example
pascal-roth 6b2910f
working training
pascal-roth 6edccbe
formatter
pascal-roth 364dcab
formatting 1
ClemensSchwarke 1517bd0
formatting 2
ClemensSchwarke d7bfc7a
CNN docstrings
ClemensSchwarke 0a1756d
format actor_critic_perceptive
ClemensSchwarke 3132a7e
extend CNN to more configuration options and better exportability
ClemensSchwarke cd53d29
fixes
ClemensSchwarke b0e24ae
rename perceptive actor critic
ClemensSchwarke File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,262 @@ | ||
| # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from tensordict import TensorDict | ||
| from torch.distributions import Normal | ||
| from typing import Any | ||
|
|
||
| from rsl_rl.networks import CNN, MLP, EmpiricalNormalization | ||
|
|
||
| from .actor_critic import ActorCritic | ||
|
|
||
|
|
||
| class ActorCriticCNN(ActorCritic): | ||
| def __init__( | ||
| self, | ||
| obs: TensorDict, | ||
| obs_groups: dict[str, list[str]], | ||
| num_actions: int, | ||
| actor_obs_normalization: bool = False, | ||
| critic_obs_normalization: bool = False, | ||
| actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256], | ||
| critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256], | ||
| actor_cnn_cfg: dict[str, dict] | dict | None = None, | ||
| critic_cnn_cfg: dict[str, dict] | dict | None = None, | ||
| activation: str = "elu", | ||
| init_noise_std: float = 1.0, | ||
| noise_std_type: str = "scalar", | ||
| state_dependent_std: bool = False, | ||
| **kwargs: dict[str, Any], | ||
| ) -> None: | ||
| if kwargs: | ||
| print( | ||
| "ActorCriticCNN.__init__ got unexpected arguments, which will be ignored: " | ||
| + str([key for key in kwargs]) | ||
| ) | ||
| super(ActorCritic, self).__init__() | ||
|
|
||
| # Get the observation dimensions | ||
| self.obs_groups = obs_groups | ||
| num_actor_obs_1d = 0 | ||
| self.actor_obs_groups_1d = [] | ||
| actor_in_dims_2d = [] | ||
| actor_in_channels_2d = [] | ||
| self.actor_obs_groups_2d = [] | ||
| for obs_group in obs_groups["policy"]: | ||
| if len(obs[obs_group].shape) == 4: # B, C, H, W | ||
| self.actor_obs_groups_2d.append(obs_group) | ||
| actor_in_dims_2d.append(obs[obs_group].shape[2:4]) | ||
| actor_in_channels_2d.append(obs[obs_group].shape[1]) | ||
| elif len(obs[obs_group].shape) == 2: # B, C | ||
| self.actor_obs_groups_1d.append(obs_group) | ||
| num_actor_obs_1d += obs[obs_group].shape[-1] | ||
| else: | ||
| raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") | ||
| num_critic_obs_1d = 0 | ||
| self.critic_obs_groups_1d = [] | ||
| critic_in_dims_2d = [] | ||
| critic_in_channels_2d = [] | ||
| self.critic_obs_groups_2d = [] | ||
| for obs_group in obs_groups["critic"]: | ||
| if len(obs[obs_group].shape) == 4: # B, C, H, W | ||
| self.critic_obs_groups_2d.append(obs_group) | ||
| critic_in_dims_2d.append(obs[obs_group].shape[2:4]) | ||
| critic_in_channels_2d.append(obs[obs_group].shape[1]) | ||
| elif len(obs[obs_group].shape) == 2: # B, C | ||
| self.critic_obs_groups_1d.append(obs_group) | ||
| num_critic_obs_1d += obs[obs_group].shape[-1] | ||
| else: | ||
| raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") | ||
ClemensSchwarke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Assert that there are 2D observations | ||
| assert self.actor_obs_groups_2d or self.critic_obs_groups_2d, ( | ||
| "No 2D observations are provided. If this is intentional, use the ActorCritic module instead." | ||
| ) | ||
|
|
||
| # Actor CNN | ||
| if self.actor_obs_groups_2d: | ||
| # Resolve the actor CNN configuration | ||
| assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations." | ||
| # If a single configuration dictionary is provided, create a dictionary for each 2D observation group | ||
| if not all(isinstance(v, dict) for v in actor_cnn_cfg.values()): | ||
| actor_cnn_cfg = {group: actor_cnn_cfg for group in self.actor_obs_groups_2d} | ||
| # Check that the number of configs matches the number of observation groups | ||
| assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), ( | ||
| "The number of CNN configurations must match the number of 2D actor observations." | ||
| ) | ||
|
|
||
| # Create CNNs for each 2D actor observation | ||
| self.actor_cnns = nn.ModuleDict() | ||
| encoding_dim = 0 | ||
| for idx, obs_group in enumerate(self.actor_obs_groups_2d): | ||
| self.actor_cnns[obs_group] = CNN( | ||
| input_dim=actor_in_dims_2d[idx], | ||
| input_channels=actor_in_channels_2d[idx], | ||
| **actor_cnn_cfg[obs_group], | ||
| ) | ||
| print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") | ||
| # Get the output dimension of the CNN | ||
| if self.actor_cnns[obs_group].output_channels is None: | ||
| encoding_dim += int(self.actor_cnns[obs_group].output_dim) # type: ignore | ||
| else: | ||
| raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.") | ||
| else: | ||
| self.actor_cnns = None | ||
| encoding_dim = 0 | ||
|
|
||
| # Actor MLP | ||
| self.state_dependent_std = state_dependent_std | ||
| if self.state_dependent_std: | ||
| self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation) | ||
ClemensSchwarke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation) | ||
| print(f"Actor MLP: {self.actor}") | ||
|
|
||
| # Actor observation normalization (only for 1D actor observations) | ||
| self.actor_obs_normalization = actor_obs_normalization | ||
| if actor_obs_normalization: | ||
| self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d) | ||
| else: | ||
| self.actor_obs_normalizer = torch.nn.Identity() | ||
|
|
||
| # Critic CNN | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for sharing CNN between Actor and Critic, can we just add a boolean flag and if true the critic CNN is just the actor one? |
||
| if self.critic_obs_groups_2d: | ||
| # Resolve the critic CNN configuration | ||
| assert critic_cnn_cfg is not None, "A critic CNN configuration is required for 2D critic observations." | ||
| # If a single configuration dictionary is provided, create a dictionary for each 2D observation group | ||
| if not all(isinstance(v, dict) for v in critic_cnn_cfg.values()): | ||
| critic_cnn_cfg = {group: critic_cnn_cfg for group in self.critic_obs_groups_2d} | ||
| # Check that the number of configs matches the number of observation groups | ||
| assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), ( | ||
| "The number of CNN configurations must match the number of 2D critic observations." | ||
| ) | ||
|
|
||
| # Create CNNs for each 2D critic observation | ||
| self.critic_cnns = nn.ModuleDict() | ||
| encoding_dim = 0 | ||
| for idx, obs_group in enumerate(self.critic_obs_groups_2d): | ||
| self.critic_cnns[obs_group] = CNN( | ||
| input_dim=critic_in_dims_2d[idx], | ||
| input_channels=critic_in_channels_2d[idx], | ||
| **critic_cnn_cfg[obs_group], | ||
| ) | ||
| print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") | ||
| # Get the output dimension of the CNN | ||
| if self.critic_cnns[obs_group].output_channels is None: | ||
| encoding_dim += int(self.critic_cnns[obs_group].output_dim) # type: ignore | ||
| else: | ||
| raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.") | ||
| else: | ||
| self.critic_cnns = None | ||
| encoding_dim = 0 | ||
|
|
||
| # Critic MLP | ||
| self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation) | ||
| print(f"Critic MLP: {self.critic}") | ||
|
|
||
| # Critic observation normalization (only for 1D critic observations) | ||
| self.critic_obs_normalization = critic_obs_normalization | ||
| if critic_obs_normalization: | ||
| self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d) | ||
| else: | ||
| self.critic_obs_normalizer = torch.nn.Identity() | ||
|
|
||
| # Action noise | ||
| self.noise_std_type = noise_std_type | ||
ClemensSchwarke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if self.state_dependent_std: | ||
| torch.nn.init.zeros_(self.actor[-2].weight[num_actions:]) | ||
| if self.noise_std_type == "scalar": | ||
| torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std) | ||
| elif self.noise_std_type == "log": | ||
| torch.nn.init.constant_( | ||
| self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)) | ||
| ) | ||
| else: | ||
| raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") | ||
| else: | ||
| if self.noise_std_type == "scalar": | ||
| self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) | ||
| elif self.noise_std_type == "log": | ||
| self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) | ||
| else: | ||
| raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") | ||
|
|
||
| # Action distribution | ||
| # Note: Populated in update_distribution | ||
| self.distribution = None | ||
|
|
||
| # Disable args validation for speedup | ||
| Normal.set_default_validate_args(False) | ||
|
|
||
| def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None: | ||
| if self.actor_cnns is not None: | ||
ClemensSchwarke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Encode the 2D actor observations | ||
| cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] | ||
| cnn_enc = torch.cat(cnn_enc_list, dim=-1) | ||
| # Concatenate to the MLP observations | ||
| mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) | ||
|
|
||
| super()._update_distribution(mlp_obs) | ||
|
|
||
| def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: | ||
| mlp_obs, cnn_obs = self.get_actor_obs(obs) | ||
| mlp_obs = self.actor_obs_normalizer(mlp_obs) | ||
| self._update_distribution(mlp_obs, cnn_obs) | ||
| return self.distribution.sample() | ||
|
|
||
| def act_inference(self, obs: TensorDict) -> torch.Tensor: | ||
| mlp_obs, cnn_obs = self.get_actor_obs(obs) | ||
| mlp_obs = self.actor_obs_normalizer(mlp_obs) | ||
|
|
||
| if self.actor_cnns is not None: | ||
| # Encode the 2D actor observations | ||
| cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] | ||
| cnn_enc = torch.cat(cnn_enc_list, dim=-1) | ||
| # Concatenate to the MLP observations | ||
| mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) | ||
|
|
||
| if self.state_dependent_std: | ||
| return self.actor(mlp_obs)[..., 0, :] | ||
| else: | ||
| return self.actor(mlp_obs) | ||
|
|
||
| def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: | ||
| mlp_obs, cnn_obs = self.get_critic_obs(obs) | ||
| mlp_obs = self.critic_obs_normalizer(mlp_obs) | ||
|
|
||
| if self.critic_cnns is not None: | ||
| # Encode the 2D critic observations | ||
| cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d] | ||
| cnn_enc = torch.cat(cnn_enc_list, dim=-1) | ||
| # Concatenate to the MLP observations | ||
| mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) | ||
|
|
||
| return self.critic(mlp_obs) | ||
|
|
||
| def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | ||
| obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d] | ||
| obs_dict_2d = {} | ||
| for obs_group in self.actor_obs_groups_2d: | ||
| obs_dict_2d[obs_group] = obs[obs_group] | ||
| return torch.cat(obs_list_1d, dim=-1), obs_dict_2d | ||
epalmaEth marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | ||
| obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d] | ||
| obs_dict_2d = {} | ||
| for obs_group in self.critic_obs_groups_2d: | ||
| obs_dict_2d[obs_group] = obs[obs_group] | ||
| return torch.cat(obs_list_1d, dim=-1), obs_dict_2d | ||
|
|
||
| def update_normalization(self, obs: TensorDict) -> None: | ||
| if self.actor_obs_normalization: | ||
| actor_obs, _ = self.get_actor_obs(obs) | ||
| self.actor_obs_normalizer.update(actor_obs) | ||
| if self.critic_obs_normalization: | ||
| critic_obs, _ = self.get_critic_obs(obs) | ||
| self.critic_obs_normalizer.update(critic_obs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.