From 677a7b20eb7b323a4bffaaf7ba252de7a7aee22b Mon Sep 17 00:00:00 2001
From: Pascal Roth
Date: Sat, 13 Sep 2025 18:41:06 +0200
Subject: [PATCH 01/10] add files for perceptive example

---
 rsl_rl/modules/__init__.py                |   2 +
 rsl_rl/modules/perceptive_actor_critic.py | 236 ++++++++++++++++++++++
 rsl_rl/networks/__init__.py               |   1 +
 rsl_rl/networks/cnn.py                    |  94 +++++++++
 4 files changed, 333 insertions(+)
 create mode 100644 rsl_rl/modules/perceptive_actor_critic.py
 create mode 100644 rsl_rl/networks/cnn.py

diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index efb8613a..04a684c1 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -7,6 +7,7 @@
 
 from .actor_critic import ActorCritic
 from .actor_critic_recurrent import ActorCriticRecurrent
+from .perceptive_actor_critic import PerceptiveActorCritic
 from .rnd import RandomNetworkDistillation, resolve_rnd_config
 from .student_teacher import StudentTeacher
 from .student_teacher_recurrent import StudentTeacherRecurrent
@@ -15,6 +16,7 @@
 __all__ = [
     "ActorCritic",
     "ActorCriticRecurrent",
+    "PerceptiveActorCritic",
     "RandomNetworkDistillation",
     "StudentTeacher",
     "StudentTeacherRecurrent",
diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/perceptive_actor_critic.py
new file mode 100644
index 00000000..ff270645
--- /dev/null
+++ b/rsl_rl/modules/perceptive_actor_critic.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch.distributions import Normal
+
+from .actor_critic import ActorCritic
+
+from rsl_rl.networks import MLP, CNN, CNNConfig, EmpiricalNormalization
+
+
+class PerceptiveActorCritic(ActorCritic):
+    def __init__(
+        self,
+        obs,
+        obs_groups,
+        num_actions,
+        actor_obs_normalization: bool = False,
+        critic_obs_normalization: bool = False,
+        actor_hidden_dims: list[int] = [256, 256, 256],
+        critic_hidden_dims: list[int] = [256, 256, 256],
+        actor_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
+        critic_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
+        activation: str = "elu",
+        init_noise_std: float = 1.0,
+        noise_std_type: str = "scalar",
+        **kwargs,
+    ):
+        if kwargs:
+            print(
+                "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
+                + str([key for key in kwargs.keys()])
+            )
+        nn.Module.__init__(self)
+
+        # get the observation dimensions
+        self.obs_groups = obs_groups
+        num_actor_obs = 0
+        num_actor_in_channels = []
+        self.actor_obs_group_1d = []
+        self.actor_obs_group_2d = []
+        for obs_group in obs_groups["policy"]:
+            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
+                self.actor_obs_group_2d.append(obs_group)
+                num_actor_in_channels.append(obs[obs_group].shape[0])
+            elif len(obs[obs_group].shape) == 1:
+                self.actor_obs_group_1d.append(obs_group)
+                num_actor_obs += obs[obs_group].shape[-1]
+            else:
+                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
+
+        self.critic_obs_group_1d = []
+        self.critic_obs_group_2d = []
+        num_critic_obs = 0
+        num_critic_in_channels = []
+        for obs_group in obs_groups["critic"]:
+            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
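+                # Illustrative intent (group names are assumptions, not part of this patch):
+                # image-like groups such as obs["height_scan"] take this CNN branch, while
+                # flat groups such as obs["proprio"] take the 1D branch and feed the MLP input.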
+                self.critic_obs_group_2d.append(obs_group)
+                num_critic_in_channels.append(obs[obs_group].shape[0])
+            else:
+                self.critic_obs_group_1d.append(obs_group)
+                num_critic_obs += obs[obs_group].shape[-1]
+
+        # actor cnn
+        if self.actor_obs_group_2d:
+            assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."
+
+            # check if multiple 2D actor observations are provided
+            if len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, CNNConfig):
+                print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.")
+                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
+            elif len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, dict):
+                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
+            elif len(self.actor_obs_group_2d) == 1 and isinstance(actor_cnn_config, CNNConfig):
+                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
+            else:
+                raise ValueError(f"Invalid combination of 2D actor observations {self.actor_obs_group_2d} and actor CNN config {actor_cnn_config}.")
+
+            self.actor_cnns = {}
+            encoding_dims = []
+            for idx, obs_group in enumerate(self.actor_obs_group_2d):
+                self.actor_cnns[obs_group] = CNN(actor_cnn_config[obs_group], num_actor_in_channels[idx], activation)
+                print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
+
+                # compute the encoding dimension
+                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group]).shape[-1])
+
+            encoding_dim = sum(encoding_dims)
+        else:
+            self.actor_cnns = None
+            encoding_dim = 0
+
+        # actor mlp
+        self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
+
+        # actor observation normalization (only for 1D actor observations)
+        self.actor_obs_normalization = actor_obs_normalization
+        if actor_obs_normalization:
+            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
+        else:
+            self.actor_obs_normalizer = torch.nn.Identity()
+        print(f"Actor MLP: {self.actor}")
+
+        # critic cnn
+        if self.critic_obs_group_2d:
+            assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."
+
+            # check if multiple 2D critic observations are provided
+            if len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, CNNConfig):
+                print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.")
+                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)))
+            elif len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, dict):
+                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
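+                # e.g. critic_cnn_config = {"depth": CNNConfig(...), "height_scan": CNNConfig(...)}
+                # (illustrative group names; one config per 2D observation group)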
+            elif len(self.critic_obs_group_2d) == 1 and isinstance(critic_cnn_config, CNNConfig):
+                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
+            else:
+                raise ValueError(f"Invalid combination of 2D critic observations {self.critic_obs_group_2d} and critic CNN config {critic_cnn_config}.")
+
+            self.critic_cnns = {}
+            encoding_dims = []
+            for idx, obs_group in enumerate(self.critic_obs_group_2d):
+                self.critic_cnns[obs_group] = CNN(critic_cnn_config[obs_group], num_critic_in_channels[idx], activation)
+                print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
+
+                # compute the encoding dimension
+                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group]).shape[-1])
+
+            encoding_dim = sum(encoding_dims)
+        else:
+            self.critic_cnns = None
+            encoding_dim = 0
+
+        # critic mlp
+        self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)
+
+        # critic observation normalization (only for 1D critic observations)
+        self.critic_obs_normalization = critic_obs_normalization
+        if critic_obs_normalization:
+            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
+        else:
+            self.critic_obs_normalizer = torch.nn.Identity()
+        print(f"Critic MLP: {self.critic}")
+
+        # Action noise
+        self.noise_std_type = noise_std_type
+        if self.noise_std_type == "scalar":
+            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
+        elif self.noise_std_type == "log":
+            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
+        else:
+            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+
+        # Action distribution (populated in update_distribution)
+        self.distribution: Normal = None
+        # disable args validation for speedup
+        Normal.set_default_validate_args(False)
+
+    def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
+
+        if self.actor_cnns is not None:
+            # encode the 2D actor observations
+            cnn_enc_list = []
+            for obs_group in self.actor_obs_group_2d:
+                cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
+            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
+            # update mlp obs
+            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
+
+        super().update_distribution(mlp_obs)
+
+    def act(self, obs, **kwargs):
+        mlp_obs, cnn_obs = self.get_actor_obs(obs)
+        mlp_obs = self.actor_obs_normalizer(mlp_obs)
+        self.update_distribution(mlp_obs, cnn_obs)
+        return self.distribution.sample()
+
+    def act_inference(self, obs):
+        mlp_obs, cnn_obs = self.get_actor_obs(obs)
+        mlp_obs = self.actor_obs_normalizer(mlp_obs)
+
+        if self.actor_cnns is not None:
+            # encode the 2D actor observations
+            cnn_enc_list = []
+            for obs_group in self.actor_obs_group_2d:
+                cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
+            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
+            # update mlp obs
+            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
+
+        return self.actor(mlp_obs)
+
+    def evaluate(self, obs, **kwargs):
+        mlp_obs, cnn_obs = self.get_critic_obs(obs)
+        mlp_obs = self.critic_obs_normalizer(mlp_obs)
+
+        if self.critic_cnns is not None:
+            # encode the 2D critic observations
+            cnn_enc_list = []
+            for obs_group in self.critic_obs_group_2d:
+                cnn_enc_list.append(self.critic_cnns[obs_group](cnn_obs[obs_group]))
+            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
+            # update mlp obs
+            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
+
+        return self.critic(mlp_obs)
+
+    def get_actor_obs(self, obs):
+        obs_list_1d = []
+        obs_dict_2d = {}
+        for obs_group in self.actor_obs_group_1d:
+            obs_list_1d.append(obs[obs_group])
+        for obs_group in self.actor_obs_group_2d:
+            obs_dict_2d[obs_group] = obs[obs_group]
+        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
+
+    def get_critic_obs(self, obs):
+        obs_list_1d = []
+        obs_dict_2d = {}
+        for obs_group in self.critic_obs_group_1d:
+            obs_list_1d.append(obs[obs_group])
+        for obs_group in self.critic_obs_group_2d:
+            obs_dict_2d[obs_group] = obs[obs_group]
+        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
+
+    def update_normalization(self, obs):
+        if self.actor_obs_normalization:
+            actor_obs, _ = self.get_actor_obs(obs)
+            self.actor_obs_normalizer.update(actor_obs)
+        if self.critic_obs_normalization:
+            critic_obs, _ = self.get_critic_obs(obs)
+            self.critic_obs_normalizer.update(critic_obs)
\ No newline at end of file
diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py
index 7ede0665..830f86d1 100644
--- a/rsl_rl/networks/__init__.py
+++ b/rsl_rl/networks/__init__.py
@@ -7,6 +7,7 @@
 
 from .memory import HiddenState, Memory
 from .mlp import MLP
+from .cnn import CNN, CNNConfig
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
 
 __all__ = [
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
new file mode 100644
index 00000000..a1368635
--- /dev/null
+++ b/rsl_rl/networks/cnn.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import torch
+from dataclasses import MISSING, dataclass
+from torch import nn as nn
+
+from rsl_rl.utils import resolve_nn_activation
+
+
+@dataclass
+class CNNConfig:
+    out_channels: list[int] = MISSING
+    kernel_size: list[tuple[int, int]] | tuple[int, int] = MISSING
+    stride: list[int] | int = 1
+    flatten: bool = True
+    avg_pool: tuple[int, int] | None = None
+    batchnorm: bool | list[bool] = False
+    max_pool: bool | list[bool] = False
+
+
+class CNN(nn.Module):
+    def __init__(self, cfg: CNNConfig, in_channels: int, activation: str):
+        """
+        Convolutional Neural Network model.
+
+        .. note::
+            Do not save config to allow for the model to be jit compiled.
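+
+        Example config (values are illustrative, not defaults)::
+
+            cfg = CNNConfig(out_channels=[16, 32], kernel_size=(3, 3), stride=2)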
+        """
+        super().__init__()
+
+        if isinstance(cfg.batchnorm, bool):
+            cfg.batchnorm = [cfg.batchnorm] * len(cfg.out_channels)
+        if isinstance(cfg.max_pool, bool):
+            cfg.max_pool = [cfg.max_pool] * len(cfg.out_channels)
+        if isinstance(cfg.kernel_size, tuple):
+            cfg.kernel_size = [cfg.kernel_size] * len(cfg.out_channels)
+        if isinstance(cfg.stride, int):
+            cfg.stride = [cfg.stride] * len(cfg.out_channels)
+
+        # get activation function
+        activation_function = resolve_nn_activation(activation)
+
+        # build model layers
+        modules = []
+
+        for idx in range(len(cfg.out_channels)):
+            in_channels = cfg.in_channels if idx == 0 else cfg.out_channels[idx - 1]
+            modules.append(
+                nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=cfg.out_channels[idx],
+                    kernel_size=cfg.kernel_size[idx],
+                    stride=cfg.stride[idx],
+                )
+            )
+            if cfg.batchnorm[idx]:
+                modules.append(nn.BatchNorm2d(num_features=cfg.out_channels[idx]))
+            modules.append(activation_function)
+            if cfg.max_pool[idx]:
+                modules.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+        self.architecture = nn.Sequential(*modules)
+
+        if cfg.avg_pool is not None:
+            self.avgpool = nn.AdaptiveAvgPool2d(cfg.avg_pool)
+        else:
+            self.avgpool = None
+
+        # initialize weights
+        self.init_weights(self.architecture)
+
+        # save flatten config for forward function
+        self.flatten = cfg.flatten
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.architecture(x)
+        if self.flatten:
+            x = x.flatten(start_dim=1)
+        elif self.avgpool is not None:
+            x = self.avgpool(x)
+            x = x.flatten(start_dim=1)
+        return x
+
+    @staticmethod
+    def init_weights(sequential):
+        [
+            torch.nn.init.xavier_uniform_(module.weight)
+            for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Conv2d))
+        ]

From 6b2910f2c8ae752048e42eb6f61e19292dc57d6d Mon Sep 17 00:00:00 2001
From: Pascal Roth
Date: Tue, 16 Sep 2025 10:25:45 +0200
Subject: [PATCH 02/10] working training

---
 rsl_rl/modules/perceptive_actor_critic.py | 56 +++++++-------
 rsl_rl/networks/__init__.py               |  2 +-
 rsl_rl/networks/cnn.py                    | 89 ++++++++++------------
 rsl_rl/runners/on_policy_runner.py        |  4 +-
 4 files changed, 69 insertions(+), 82 deletions(-)

diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/perceptive_actor_critic.py
index ff270645..9e3e9790 100644
--- a/rsl_rl/modules/perceptive_actor_critic.py
+++ b/rsl_rl/modules/perceptive_actor_critic.py
@@ -11,7 +11,7 @@
 
 from .actor_critic import ActorCritic
 
-from rsl_rl.networks import MLP, CNN, CNNConfig, EmpiricalNormalization
+from rsl_rl.networks import MLP, CNN, EmpiricalNormalization
 
 
 class PerceptiveActorCritic(ActorCritic):
@@ -24,8 +24,8 @@ def __init__(
         critic_obs_normalization: bool = False,
         actor_hidden_dims: list[int] = [256, 256, 256],
         critic_hidden_dims: list[int] = [256, 256, 256],
-        actor_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
-        critic_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None,
+        actor_cnn_config: dict[str, dict] | dict | None = None,
+        critic_cnn_config: dict[str, dict] | dict | None = None,
         activation: str = "elu",
         init_noise_std: float = 1.0,
         noise_std_type: str = "scalar",
@@ -45,10 +45,10 @@ def __init__(
         self.actor_obs_group_1d = []
         self.actor_obs_group_2d = []
         for obs_group in obs_groups["policy"]:
-            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
+            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                 self.actor_obs_group_2d.append(obs_group)
-                num_actor_in_channels.append(obs[obs_group].shape[0])
-            elif len(obs[obs_group].shape) == 1:
+                num_actor_in_channels.append(obs[obs_group].shape[1])
+            elif len(obs[obs_group].shape) == 2:  # B, C
                 self.actor_obs_group_1d.append(obs_group)
                 num_actor_obs += obs[obs_group].shape[-1]
             else:
@@ -59,36 +59,36 @@ def __init__(
         num_critic_obs = 0
         num_critic_in_channels = []
         for obs_group in obs_groups["critic"]:
-            if len(obs[obs_group].shape) == 2:  # FIXME: should be 3???
+            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                 self.critic_obs_group_2d.append(obs_group)
-                num_critic_in_channels.append(obs[obs_group].shape[0])
-            else:
+                num_critic_in_channels.append(obs[obs_group].shape[1])
+            elif len(obs[obs_group].shape) == 2:  # B, C
                 self.critic_obs_group_1d.append(obs_group)
                 num_critic_obs += obs[obs_group].shape[-1]
+            else:
+                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
 
         # actor cnn
         if self.actor_obs_group_2d:
             assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."
 
             # check if multiple 2D actor observations are provided
-            if len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, CNNConfig):
+            if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
+                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
+            elif len(self.actor_obs_group_2d) > 1:
                 print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.")
                 actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
-            elif len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, dict):
-                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
-            elif len(self.actor_obs_group_2d) == 1 and isinstance(actor_cnn_config, CNNConfig):
-                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
             else:
-                raise ValueError(f"Invalid combination of 2D actor observations {self.actor_obs_group_2d} and actor CNN config {actor_cnn_config}.")
+                actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
 
-            self.actor_cnns = {}
+            self.actor_cnns = nn.ModuleDict()
             encoding_dims = []
             for idx, obs_group in enumerate(self.actor_obs_group_2d):
-                self.actor_cnns[obs_group] = CNN(actor_cnn_config[obs_group], num_actor_in_channels[idx], activation)
+                self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_config[obs_group])
                 print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
 
-                # compute the encoding dimension
-                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group]).shape[-1])
+                # compute the encoding dimension (cpu necessary as model not moved to device yet)
+                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
 
             encoding_dim = sum(encoding_dims)
         else:
@@ -111,24 +111,22 @@ def __init__(
             assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."
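+            # NOTE: a single config dict is broadcast to every 2D group below; e.g. groups
+            # ["depth", "height"] with one config c become {"depth": c, "height": c}, i.e.
+            # the same architecture with separate weights (group names are illustrative).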
 
             # check if multiple 2D critic observations are provided
-            if len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, CNNConfig):
+            if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
+                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
+            elif len(self.critic_obs_group_2d) > 1:
                 print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.")
                 critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)))
-            elif len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, dict):
-                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
-            elif len(self.critic_obs_group_2d) == 1 and isinstance(critic_cnn_config, CNNConfig):
-                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
             else:
-                raise ValueError(f"Invalid combination of 2D critic observations {self.critic_obs_group_2d} and critic CNN config {critic_cnn_config}.")
+                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
 
-            self.critic_cnns = {}
+            self.critic_cnns = nn.ModuleDict()
             encoding_dims = []
             for idx, obs_group in enumerate(self.critic_obs_group_2d):
-                self.critic_cnns[obs_group] = CNN(critic_cnn_config[obs_group], num_critic_in_channels[idx], activation)
+                self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group])
                 print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
 
-                # compute the encoding dimension
-                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group]).shape[-1])
+                # compute the encoding dimension (cpu necessary as model not moved to device yet)
+                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
 
             encoding_dim = sum(encoding_dims)
         else:
diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py
index 830f86d1..d0514bdf 100644
--- a/rsl_rl/networks/__init__.py
+++ b/rsl_rl/networks/__init__.py
@@ -7,7 +7,7 @@
 
 from .memory import HiddenState, Memory
 from .mlp import MLP
-from .cnn import CNN, CNNConfig
+from .cnn import CNN
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
 
 __all__ = [
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
index a1368635..fbe06b07 100644
--- a/rsl_rl/networks/cnn.py
+++ b/rsl_rl/networks/cnn.py
@@ -6,25 +6,13 @@
 from __future__ import annotations
 
 import torch
-from dataclasses import MISSING, dataclass
 from torch import nn as nn
 
 from rsl_rl.utils import resolve_nn_activation
 
 
-@dataclass
-class CNNConfig:
-    out_channels: list[int] = MISSING
-    kernel_size: list[tuple[int, int]] | tuple[int, int] = MISSING
-    stride: list[int] | int = 1
-    flatten: bool = True
-    avg_pool: tuple[int, int] | None = None
-    batchnorm: bool | list[bool] = False
-    max_pool: bool | list[bool] = False
-
-
-class CNN(nn.Module):
-    def __init__(self, cfg: CNNConfig, in_channels: int, activation: str):
+class CNN(nn.Sequential):
+    def __init__(self, in_channels: int, activation: str, out_channels: list[int], kernel_size: list[tuple[int, int]] | tuple[int, int], stride: list[int] | int = 1, flatten: bool = True, avg_pool: tuple[int, int] | None = None, batchnorm: bool | list[bool] = False, max_pool: bool | list[bool] = False):
         """
         Convolutional Neural Network model.
@@ -33,52 +21,52 @@ def __init__(self, cfg: CNNConfig, in_channels: int, activation: str):
         """
-        if isinstance(cfg.batchnorm, bool):
-            cfg.batchnorm = [cfg.batchnorm] * len(cfg.out_channels)
-        if isinstance(cfg.max_pool, bool):
-            cfg.max_pool = [cfg.max_pool] * len(cfg.out_channels)
-        if isinstance(cfg.kernel_size, tuple):
-            cfg.kernel_size = [cfg.kernel_size] * len(cfg.out_channels)
-        if isinstance(cfg.stride, int):
-            cfg.stride = [cfg.stride] * len(cfg.out_channels)
+        if isinstance(batchnorm, bool):
+            batchnorm = [batchnorm] * len(out_channels)
+        if isinstance(max_pool, bool):
+            max_pool = [max_pool] * len(out_channels)
+        if isinstance(kernel_size, tuple):
+            kernel_size = [kernel_size] * len(out_channels)
+        if isinstance(stride, int):
+            stride = [stride] * len(out_channels)
 
         # get activation function
         activation_function = resolve_nn_activation(activation)
 
         # build model layers
-        modules = []
+        layers = []
 
-        for idx in range(len(cfg.out_channels)):
-            in_channels = cfg.in_channels if idx == 0 else cfg.out_channels[idx - 1]
-            modules.append(
+        for idx in range(len(out_channels)):
+            in_channels = in_channels if idx == 0 else out_channels[idx - 1]
+            layers.append(
                 nn.Conv2d(
                     in_channels=in_channels,
-                    out_channels=cfg.out_channels[idx],
-                    kernel_size=cfg.kernel_size[idx],
-                    stride=cfg.stride[idx],
+                    out_channels=out_channels[idx],
+                    kernel_size=kernel_size[idx],
+                    stride=stride[idx],
                 )
             )
-            if cfg.batchnorm[idx]:
-                modules.append(nn.BatchNorm2d(num_features=cfg.out_channels[idx]))
-            modules.append(activation_function)
-            if cfg.max_pool[idx]:
-                modules.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
-
-        self.architecture = nn.Sequential(*modules)
-
-        if cfg.avg_pool is not None:
-            self.avgpool = nn.AdaptiveAvgPool2d(cfg.avg_pool)
+            if batchnorm[idx]:
+                layers.append(nn.BatchNorm2d(num_features=out_channels[idx]))
+            layers.append(activation_function)
+            if max_pool[idx]:
+                layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+        # register the layers
+        for idx, layer in enumerate(layers):
+            self.add_module(f"{idx}", layer)
+
+        if avg_pool is not None:
+            self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
         else:
             self.avgpool = None
 
-        # initialize weights
-        self.init_weights(self.architecture)
-
-        # save flatten config for forward function
-        self.flatten = cfg.flatten
+        # save flatten config for forward function
+        self.flatten = flatten
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.architecture(x)
+        for layer in self:
+            x = layer(x)
         if self.flatten:
             x = x.flatten(start_dim=1)
         elif self.avgpool is not None:
@@ -86,9 +74,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = x.flatten(start_dim=1)
         return x
 
-    @staticmethod
-    def init_weights(sequential):
-        [
-            torch.nn.init.xavier_uniform_(module.weight)
-            for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Conv2d))
-        ]
+    def init_weights(self, scales: float | tuple[float]):
+        """Initialize the weights of the CNN."""
+
+        # initialize the weights
+        for idx, module in enumerate(self):
+            if isinstance(module, nn.Conv2d):
+                nn.init.xavier_uniform_(module.weight)
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index 46a9b524..2dbdfd92 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -16,7 +16,7 @@
 import rsl_rl
 from rsl_rl.algorithms import PPO
 from rsl_rl.env import VecEnv
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config
+from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, PerceptiveActorCritic, resolve_rnd_config, resolve_symmetry_config
 from rsl_rl.utils import resolve_obs_groups, store_code_state
 
 
@@ -414,7 +414,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
 
         # Initialize the policy
         actor_critic_class = eval(self.policy_cfg.pop("class_name"))
-        actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
+        actor_critic: ActorCritic | ActorCriticRecurrent | PerceptiveActorCritic = actor_critic_class(
             obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
         ).to(self.device)

From 6edccbe56fcb01a50986471ca3ab0bd2b2e047b5 Mon Sep 17 00:00:00 2001
From: Pascal Roth
Date: Tue, 16 Sep 2025 10:27:01 +0200
Subject: [PATCH 03/10] formatter

---
 rsl_rl/modules/perceptive_actor_critic.py | 56 ++++++++++++++---------
 rsl_rl/networks/__init__.py               |  2 +-
 rsl_rl/networks/cnn.py                    | 13 +++++-
 rsl_rl/runners/on_policy_runner.py        |  8 +++-
 4 files changed, 54 insertions(+), 25 deletions(-)

diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/perceptive_actor_critic.py
index 9e3e9790..9862a6c6 100644
--- a/rsl_rl/modules/perceptive_actor_critic.py
+++ b/rsl_rl/modules/perceptive_actor_critic.py
@@ -9,13 +9,13 @@
 import torch.nn as nn
 from torch.distributions import Normal
 
-from .actor_critic import ActorCritic
+from rsl_rl.networks import CNN, MLP, EmpiricalNormalization
 
-from rsl_rl.networks import MLP, CNN, EmpiricalNormalization
+from .actor_critic import ActorCritic
 
 
 class PerceptiveActorCritic(ActorCritic):
-    def __init__(
+    def __init__(  # noqa: C901
         self,
         obs,
         obs_groups,
@@ -53,7 +53,7 @@ def __init__(
             num_actor_obs += obs[obs_group].shape[-1]
         else:
             raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
-
+
         self.critic_obs_group_1d = []
         self.critic_obs_group_2d = []
         num_critic_obs = 0
@@ -71,12 +71,16 @@ def __init__(
         # actor cnn
         if self.actor_obs_group_2d:
             assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations."
-
+
             # check if multiple 2D actor observations are provided
             if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
-                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations."
+                assert len(actor_cnn_config) == len(
+                    self.actor_obs_group_2d
+                ), "Number of CNN configs must match number of 2D actor observations."
             elif len(self.actor_obs_group_2d) > 1:
-                print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.")
+                print(
+                    "Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups."
+                )
                 actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d)))
             else:
                 actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config]))
@@ -89,7 +93,7 @@ def __init__(
                 # compute the encoding dimension (cpu necessary as model not moved to device yet)
                 encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
-
+
             encoding_dim = sum(encoding_dims)
         else:
             self.actor_cnns = None
@@ -97,7 +101,7 @@ def __init__(
         # actor mlp
         self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
-
+
         # actor observation normalization (only for 1D actor observations)
         self.actor_obs_normalization = actor_obs_normalization
         if actor_obs_normalization:
@@ -109,25 +113,33 @@ def __init__(
         # critic cnn
         if self.critic_obs_group_2d:
             assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations."
-
+
             # check if multiple 2D critic observations are provided
             if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
-                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations."
+                assert len(critic_cnn_config) == len(
+                    self.critic_obs_group_2d
+                ), "Number of CNN configs must match number of 2D critic observations."
             elif len(self.critic_obs_group_2d) > 1:
-                print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.")
-                critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)))
+                print(
+                    "Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups."
+                )
+                critic_cnn_config = dict(
+                    zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d))
+                )
             else:
                 critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config]))
 
             self.critic_cnns = nn.ModuleDict()
             encoding_dims = []
             for idx, obs_group in enumerate(self.critic_obs_group_2d):
-                self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group])
+                self.critic_cnns[obs_group] = CNN(
+                    num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group]
+                )
                 print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
 
                 # compute the encoding dimension (cpu necessary as model not moved to device yet)
                 encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
-
+
             encoding_dim = sum(encoding_dims)
         else:
             self.critic_cnns = None
@@ -135,7 +147,7 @@ def __init__(
         # critic mlp
         self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)
-
+
         # critic observation normalization (only for 1D critic observations)
         self.critic_obs_normalization = critic_obs_normalization
         if critic_obs_normalization:
@@ -159,7 +171,7 @@ def __init__(
         Normal.set_default_validate_args(False)
 
     def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
-
+
         if self.actor_cnns is not None:
             # encode the 2D actor observations
             cnn_enc_list = []
@@ -168,7 +180,7 @@ def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Te
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
             # update mlp obs
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
-
+
         super().update_distribution(mlp_obs)
 
     def act(self, obs, **kwargs):
@@ -180,7 +192,7 @@ def act(self, obs, **kwargs):
     def act_inference(self, obs):
         mlp_obs, cnn_obs = self.get_actor_obs(obs)
         mlp_obs = self.actor_obs_normalizer(mlp_obs)
-
+
         if self.actor_cnns is not None:
             # encode the 2D actor observations
             cnn_enc_list = []
@@ -189,7 +201,7 @@ def act_inference(self, obs):
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
             # update mlp obs
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
-
+
         return self.actor(mlp_obs)
 
     def evaluate(self, obs, **kwargs):
@@ -204,7 +216,7 @@ def evaluate(self, obs, **kwargs):
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
             # update mlp obs
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
-
+
         return self.critic(mlp_obs)
 
     def get_actor_obs(self, obs):
@@ -231,4 +243,4 @@ def update_normalization(self, obs):
             self.actor_obs_normalizer.update(actor_obs)
         if self.critic_obs_normalization:
             critic_obs, _ = self.get_critic_obs(obs)
-            self.critic_obs_normalizer.update(critic_obs)
\ No newline at end of file
+            self.critic_obs_normalizer.update(critic_obs)
diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py
index d0514bdf..8eb6ac48 100644
--- a/rsl_rl/networks/__init__.py
+++ b/rsl_rl/networks/__init__.py
@@ -5,9 +5,9 @@
 
 """Definitions for components of modules."""
 
+from .cnn import CNN
 from .memory import HiddenState, Memory
 from .mlp import MLP
-from .cnn import CNN
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
 
 __all__ = [
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
index fbe06b07..49b576ed 100644
--- a/rsl_rl/networks/cnn.py
+++ b/rsl_rl/networks/cnn.py
@@ -12,7 +12,18 @@
 
 class CNN(nn.Sequential):
-    def __init__(self, in_channels: int, activation: str, out_channels: list[int], kernel_size: list[tuple[int, int]] | tuple[int, int], stride: list[int] | int = 1, flatten: bool = True, avg_pool: tuple[int, int] | None = None, batchnorm: bool | list[bool] = False, max_pool: bool | list[bool] = False):
+    def __init__(
+        self,
+        in_channels: int,
+        activation: str,
+        out_channels: list[int],
+        kernel_size: list[tuple[int, int]] | tuple[int, int],
+        stride: list[int] | int = 1,
+        flatten: bool = True,
+        avg_pool: tuple[int, int] | None = None,
+        batchnorm: bool | list[bool] = False,
+        max_pool: bool | list[bool] = False,
+    ):
         """
         Convolutional Neural Network model.
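
Note: the keyword-based CNN interface above is driven from the policy configuration. A minimal
sketch of such a configuration (group names and values are illustrative assumptions, not part of
this patch); the runner consumes "class_name" via eval() in the diff below and forwards the
remaining keys as keyword arguments:

    policy_cfg = {
        "class_name": "PerceptiveActorCritic",
        "actor_cnn_config": {"out_channels": [16, 32], "kernel_size": (3, 3), "stride": 2},
        "critic_cnn_config": {"out_channels": [16, 32], "kernel_size": (3, 3), "stride": 2},
    }
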
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index 2dbdfd92..b8659d07 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -16,7 +16,13 @@
 import rsl_rl
 from rsl_rl.algorithms import PPO
 from rsl_rl.env import VecEnv
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, PerceptiveActorCritic, resolve_rnd_config, resolve_symmetry_config
+from rsl_rl.modules import (
+    ActorCritic,
+    ActorCriticRecurrent,
+    PerceptiveActorCritic,
+    resolve_rnd_config,
+    resolve_symmetry_config,
+)
 from rsl_rl.utils import resolve_obs_groups, store_code_state

From 364dcab9387274fa4df1f40ed155bbfbd14577f5 Mon Sep 17 00:00:00 2001
From: ClemensSchwarke
Date: Thu, 23 Oct 2025 11:58:08 +0200
Subject: [PATCH 04/10] formatting 1

---
 rsl_rl/modules/__init__.py                    |  4 ++--
 ...r_critic.py => actor_critic_perceptive.py} | 19 +++++++++----------
 rsl_rl/networks/__init__.py                   |  1 +
 rsl_rl/networks/cnn.py                        |  4 +---
 rsl_rl/runners/on_policy_runner.py            |  4 ++--
 5 files changed, 15 insertions(+), 17 deletions(-)
 rename rsl_rl/modules/{perceptive_actor_critic.py => actor_critic_perceptive.py} (95%)

diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index 04a684c1..7803aa08 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -6,8 +6,8 @@
 """Definitions for neural-network components for RL-agents."""
 
 from .actor_critic import ActorCritic
+from .actor_critic_perceptive import ActorCriticPerceptive
 from .actor_critic_recurrent import ActorCriticRecurrent
-from .perceptive_actor_critic import PerceptiveActorCritic
 from .rnd import RandomNetworkDistillation, resolve_rnd_config
 from .student_teacher import StudentTeacher
 from .student_teacher_recurrent import StudentTeacherRecurrent
@@ -15,8 +15,8 @@
 __all__ = [
     "ActorCritic",
+    "ActorCriticPerceptive",
     "ActorCriticRecurrent",
-    "PerceptiveActorCritic",
     "RandomNetworkDistillation",
     "StudentTeacher",
     "StudentTeacherRecurrent",
diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/actor_critic_perceptive.py
similarity index 95%
rename from rsl_rl/modules/perceptive_actor_critic.py
rename to rsl_rl/modules/actor_critic_perceptive.py
index 9862a6c6..3afca634 100644
--- a/rsl_rl/modules/perceptive_actor_critic.py
+++ b/rsl_rl/modules/actor_critic_perceptive.py
@@ -14,8 +14,8 @@
 from .actor_critic import ActorCritic
 
 
-class PerceptiveActorCritic(ActorCritic):
-    def __init__(  # noqa: C901
+class ActorCriticPerceptive(ActorCritic):
+    def __init__(
         self,
         obs,
         obs_groups,
@@ -34,7 +34,7 @@ def __init__(
         if kwargs:
             print(
                 "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
-                + str([key for key in kwargs.keys()])
+                + str([key for key in kwargs])
             )
         nn.Module.__init__(self)
 
@@ -74,9 +74,9 @@ def __init__(
             # check if multiple 2D actor observations are provided
             if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()):
-                assert len(actor_cnn_config) == len(
-                    self.actor_obs_group_2d
-                ), "Number of CNN configs must match number of 2D actor observations."
+                assert len(actor_cnn_config) == len(self.actor_obs_group_2d), (
+                    "Number of CNN configs must match number of 2D actor observations."
+                )
             elif len(self.actor_obs_group_2d) > 1:
                 print(
                     "Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups."
@@ -116,9 +116,9 @@ def __init__(
             # check if multiple 2D critic observations are provided
             if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()):
-                assert len(critic_cnn_config) == len(
-                    self.critic_obs_group_2d
-                ), "Number of CNN configs must match number of 2D critic observations."
+                assert len(critic_cnn_config) == len(self.critic_obs_group_2d), (
+                    "Number of CNN configs must match number of 2D critic observations."
+                )
             elif len(self.critic_obs_group_2d) > 1:
                 print(
                     "Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups."
@@ -171,7 +171,6 @@ def __init__(
         Normal.set_default_validate_args(False)
 
     def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
-
         if self.actor_cnns is not None:
             # encode the 2D actor observations
             cnn_enc_list = []
diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py
index 8eb6ac48..5050fcc0 100644
--- a/rsl_rl/networks/__init__.py
+++ b/rsl_rl/networks/__init__.py
@@ -11,6 +11,7 @@
 from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
 
 __all__ = [
+    "CNN",
     "MLP",
     "EmpiricalDiscountedVariationNormalization",
     "EmpiricalNormalization",
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
index 49b576ed..71114ed9 100644
--- a/rsl_rl/networks/cnn.py
+++ b/rsl_rl/networks/cnn.py
@@ -24,8 +24,7 @@ def __init__(
         batchnorm: bool | list[bool] = False,
         max_pool: bool | list[bool] = False,
     ):
-        """
-        Convolutional Neural Network model.
+        """Convolutional Neural Network model.
 
         .. note::
             Do not save config to allow for the model to be jit compiled.
@@ -87,7 +86,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     def init_weights(self, scales: float | tuple[float]):
         """Initialize the weights of the CNN."""
-
         # initialize the weights
         for idx, module in enumerate(self):
             if isinstance(module, nn.Conv2d):
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index b8659d07..2b0d7664 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -18,8 +18,8 @@
 from rsl_rl.env import VecEnv
 from rsl_rl.modules import (
     ActorCritic,
+    ActorCriticPerceptive,
     ActorCriticRecurrent,
-    PerceptiveActorCritic,
     resolve_rnd_config,
     resolve_symmetry_config,
 )
@@ -420,7 +420,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
 
         # Initialize the policy
         actor_critic_class = eval(self.policy_cfg.pop("class_name"))
-        actor_critic: ActorCritic | ActorCriticRecurrent | PerceptiveActorCritic = actor_critic_class(
+        actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive = actor_critic_class(
             obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
         ).to(self.device)

From 1517bd0c001da6e0760748eb978d5d13d629f222 Mon Sep 17 00:00:00 2001
From: ClemensSchwarke
Date: Thu, 23 Oct 2025 12:01:44 +0200
Subject: [PATCH 05/10] formatting 2

---
 rsl_rl/modules/actor_critic_perceptive.py | 26 +++++++----------------
 rsl_rl/networks/cnn.py                    |  4 ++--
 2 files changed, 10 insertions(+), 20 deletions(-)

diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_perceptive.py
index 3afca634..dbdb0390 100644
--- a/rsl_rl/modules/actor_critic_perceptive.py
+++ b/rsl_rl/modules/actor_critic_perceptive.py
@@ -30,7 +30,7 @@ def __init__(
         init_noise_std: float = 1.0,
         noise_std_type: str = "scalar",
         **kwargs,
-    ):
+    ) -> None:
         if kwargs:
             print(
                 "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
@@ -170,12 +170,10 @@ def __init__(
         # disable args validation for speedup
         Normal.set_default_validate_args(False)
 
-    def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
+    def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
         if self.actor_cnns is not None:
             # encode the 2D actor observations
-            cnn_enc_list = []
-            for obs_group in self.actor_obs_group_2d:
-                cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
+            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
             # update mlp obs
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
@@ -194,9 +192,7 @@ def act_inference(self, obs):
 
         if self.actor_cnns is not None:
             # encode the 2D actor observations
-            cnn_enc_list = []
-            for obs_group in self.actor_obs_group_2d:
-                cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group]))
+            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
             # update mlp obs
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
@@ -209,9 +205,7 @@ def evaluate(self, obs, **kwargs):
 
         if self.critic_cnns is not None:
             # encode the 2D critic observations
-            cnn_enc_list = []
-            for obs_group in self.critic_obs_group_2d:
-                cnn_enc_list.append(self.critic_cnns[obs_group](cnn_obs[obs_group]))
+            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_group_2d]
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
             # update mlp obs
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
@@ -219,24 +213,20 @@ def evaluate(self, obs, **kwargs):
         return self.critic(mlp_obs)
 
     def get_actor_obs(self, obs):
-        obs_list_1d = []
         obs_dict_2d = {}
-        for obs_group in self.actor_obs_group_1d:
-            obs_list_1d.append(obs[obs_group])
+        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_group_1d]
         for obs_group in self.actor_obs_group_2d:
             obs_dict_2d[obs_group] = obs[obs_group]
         return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
 
     def get_critic_obs(self, obs):
-        obs_list_1d = []
         obs_dict_2d = {}
-        for obs_group in self.critic_obs_group_1d:
-            obs_list_1d.append(obs[obs_group])
+        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_group_1d]
         for obs_group in self.critic_obs_group_2d:
             obs_dict_2d[obs_group] = obs[obs_group]
         return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
 
-    def update_normalization(self, obs):
+    def update_normalization(self, obs) -> None:
         if self.actor_obs_normalization:
             actor_obs, _ = self.get_actor_obs(obs)
             self.actor_obs_normalizer.update(actor_obs)
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
index 71114ed9..a00721e7 100644
--- a/rsl_rl/networks/cnn.py
+++ b/rsl_rl/networks/cnn.py
@@ -23,7 +23,7 @@ def __init__(
         avg_pool: tuple[int, int] | None = None,
         batchnorm: bool | list[bool] = False,
         max_pool: bool | list[bool] = False,
-    ):
+    ) -> None:
         """Convolutional Neural Network model.
 
         .. note::
@@ -84,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = x.flatten(start_dim=1)
         return x
 
-    def init_weights(self, scales: float | tuple[float]):
+    def init_weights(self, scales: float | tuple[float]) -> None:
         """Initialize the weights of the CNN."""
         # initialize the weights
         for idx, module in enumerate(self):

From d7bfc7a034f64103ce1bab725c4066a08829254f Mon Sep 17 00:00:00 2001
From: ClemensSchwarke
Date: Thu, 23 Oct 2025 12:16:40 +0200
Subject: [PATCH 06/10] CNN docstrings

---
 rsl_rl/algorithms/ppo.py  |  6 +++---
 rsl_rl/networks/cnn.py    | 35 ++++++++++++++++++++++++++---------
 rsl_rl/networks/memory.py |  4 ++--
 3 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index 1479c06a..410d52ea 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -11,7 +11,7 @@
 from itertools import chain
 from tensordict import TensorDict
 
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
+from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent
 from rsl_rl.modules.rnd import RandomNetworkDistillation
 from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import string_to_callable
@@ -20,12 +20,12 @@
 class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
 
-    policy: ActorCritic | ActorCriticRecurrent
+    policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive
     """The actor critic module."""
 
     def __init__(
         self,
-        policy: ActorCritic | ActorCriticRecurrent,
+        policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive,
         num_learning_epochs: int = 5,
         num_mini_batches: int = 4,
         clip_param: float = 0.2,
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
index a00721e7..feffb255 100644
--- a/rsl_rl/networks/cnn.py
+++ b/rsl_rl/networks/cnn.py
@@ -12,6 +12,12 @@
 
 class CNN(nn.Sequential):
+    """Convolutional Neural Network (CNN).
+
+    The CNN network is a sequence of convolutional layers, optional batch normalization, activation functions, and
+    optional max pooling. The final output can be flattened or pooled depending on the configuration.
+    """
+
     def __init__(
         self,
         in_channels: int,
@@ -24,13 +30,25 @@ def __init__(
         batchnorm: bool | list[bool] = False,
         max_pool: bool | list[bool] = False,
     ) -> None:
-        """Convolutional Neural Network model.
+        """Initialize the CNN.
+
+        Args:
+            in_channels: Number of input channels.
+            activation: Activation function to use.
+            out_channels: List of output channels for each convolutional layer.
+            kernel_size: List of kernel sizes for each convolutional layer or a single kernel size for all layers.
+            stride: List of strides for each convolutional layer or a single stride for all layers.
+            flatten: Whether to flatten the output tensor.
+            avg_pool: If specified, applies an adaptive average pooling to the given output size after the convolutions.
+            batchnorm: Whether to apply batch normalization after each convolutional layer.
+            max_pool: Whether to apply max pooling after each convolutional layer.
 
         .. note::
             Do not save config to allow for the model to be jit compiled.
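+
+        Example (a minimal sketch; channel and kernel values are illustrative)::
+
+            encoder = CNN(in_channels=1, activation="elu", out_channels=[16, 32], kernel_size=(3, 3), stride=2)
+            features = encoder(torch.zeros(8, 1, 64, 64))  # -> shape (8, 7200) with flatten=True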
""" super().__init__() + # If parameters are not lists, convert them to lists if isinstance(batchnorm, bool): batchnorm = [batchnorm] * len(out_channels) if isinstance(max_pool, bool): @@ -40,12 +58,11 @@ def __init__( if isinstance(stride, int): stride = [stride] * len(out_channels) - # get activation function + # Resolve activation function activation_function = resolve_nn_activation(activation) - # build model layers + # Create layers sequentially layers = [] - for idx in range(len(out_channels)): in_channels = in_channels if idx == 0 else out_channels[idx - 1] layers.append( @@ -62,16 +79,17 @@ def __init__( if max_pool[idx]: layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) - # register the layers + # Register the layers for idx, layer in enumerate(layers): self.add_module(f"{idx}", layer) + # Add avgpool if specified if avg_pool is not None: self.avgpool = nn.AdaptiveAvgPool2d(avg_pool) else: self.avgpool = None - # save flatten config for forward function + # Save flatten flag for forward function self.flatten = flatten def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -84,9 +102,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.flatten(start_dim=1) return x - def init_weights(self, scales: float | tuple[float]) -> None: - """Initialize the weights of the CNN.""" - # initialize the weights + def init_weights(self) -> None: + """Initialize the weights of the CNN with Xavier initialization.""" for idx, module in enumerate(self): if isinstance(module, nn.Conv2d): nn.init.xavier_uniform_(module.weight) diff --git a/rsl_rl/networks/memory.py b/rsl_rl/networks/memory.py index dd40afc2..dc67abed 100644 --- a/rsl_rl/networks/memory.py +++ b/rsl_rl/networks/memory.py @@ -18,9 +18,9 @@ class Memory(nn.Module): - """Memory module for recurrent networks. + """Memory network for recurrent architectures. - This module is used to store the hidden state of the policy. It currently supports GRU and LSTM. + This network is used to store the hidden state of the policy. It currently supports GRU and LSTM. """ def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None: From 0a1756d61cb9011475b549c2327464e8351ce412 Mon Sep 17 00:00:00 2001 From: ClemensSchwarke Date: Thu, 23 Oct 2025 14:09:53 +0200 Subject: [PATCH 07/10] format actor_critic_perceptive --- rsl_rl/modules/actor_critic.py | 5 +- rsl_rl/modules/actor_critic_perceptive.py | 194 ++++++++++++---------- rsl_rl/modules/actor_critic_recurrent.py | 5 +- 3 files changed, 111 insertions(+), 93 deletions(-) diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py index 9f01b2f4..da55e704 100644 --- a/rsl_rl/modules/actor_critic.py +++ b/rsl_rl/modules/actor_critic.py @@ -49,9 +49,8 @@ def __init__( assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations." 
num_critic_obs += obs[obs_group].shape[-1] - self.state_dependent_std = state_dependent_std - # Actor + self.state_dependent_std = state_dependent_std if self.state_dependent_std: self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation) else: @@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor: def entropy(self) -> torch.Tensor: return self.distribution.entropy().sum(dim=-1) - def _update_distribution(self, obs: TensorDict) -> None: + def _update_distribution(self, obs: torch.Tensor) -> None: if self.state_dependent_std: # Compute mean and standard deviation mean_and_std = self.actor(obs) diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_perceptive.py index dbdb0390..7d693223 100644 --- a/rsl_rl/modules/actor_critic_perceptive.py +++ b/rsl_rl/modules/actor_critic_perceptive.py @@ -7,7 +7,9 @@ import torch import torch.nn as nn +from tensordict import TensorDict from torch.distributions import Normal +from typing import Any from rsl_rl.networks import CNN, MLP, EmpiricalNormalization @@ -17,19 +19,20 @@ class ActorCriticPerceptive(ActorCritic): def __init__( self, - obs, - obs_groups, - num_actions, + obs: TensorDict, + obs_groups: dict[str, list[str]], + num_actions: int, actor_obs_normalization: bool = False, critic_obs_normalization: bool = False, actor_hidden_dims: list[int] = [256, 256, 256], critic_hidden_dims: list[int] = [256, 256, 256], - actor_cnn_config: dict[str, dict] | dict | None = None, - critic_cnn_config: dict[str, dict] | dict | None = None, + actor_cnn_cfg: dict[str, dict] | dict | None = None, + critic_cnn_cfg: dict[str, dict] | dict | None = None, activation: str = "elu", init_noise_std: float = 1.0, noise_std_type: str = "scalar", - **kwargs, + state_dependent_std: bool = False, + **kwargs: dict[str, Any], ) -> None: if kwargs: print( @@ -38,195 +41,212 @@ def __init__( ) nn.Module.__init__(self) - # get the observation dimensions + # Get the observation dimensions self.obs_groups = obs_groups num_actor_obs = 0 num_actor_in_channels = [] - self.actor_obs_group_1d = [] - self.actor_obs_group_2d = [] + self.actor_obs_groups_1d = [] + self.actor_obs_groups_2d = [] for obs_group in obs_groups["policy"]: if len(obs[obs_group].shape) == 4: # B, C, H, W - self.actor_obs_group_2d.append(obs_group) + self.actor_obs_groups_2d.append(obs_group) num_actor_in_channels.append(obs[obs_group].shape[1]) elif len(obs[obs_group].shape) == 2: # B, C - self.actor_obs_group_1d.append(obs_group) + self.actor_obs_groups_1d.append(obs_group) num_actor_obs += obs[obs_group].shape[-1] else: raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") - - self.critic_obs_group_1d = [] - self.critic_obs_group_2d = [] num_critic_obs = 0 num_critic_in_channels = [] + self.critic_obs_groups_1d = [] + self.critic_obs_groups_2d = [] for obs_group in obs_groups["critic"]: if len(obs[obs_group].shape) == 4: # B, C, H, W - self.critic_obs_group_2d.append(obs_group) + self.critic_obs_groups_2d.append(obs_group) num_critic_in_channels.append(obs[obs_group].shape[1]) elif len(obs[obs_group].shape) == 2: # B, C - self.critic_obs_group_1d.append(obs_group) + self.critic_obs_groups_1d.append(obs_group) num_critic_obs += obs[obs_group].shape[-1] else: raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") - # actor cnn - if self.actor_obs_group_2d: - assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations." 
+ # Actor CNN + if self.actor_obs_groups_2d: + assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations." - # check if multiple 2D actor observations are provided - if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()): - assert len(actor_cnn_config) == len(self.actor_obs_group_2d), ( - "Number of CNN configs must match number of 2D actor observations." + # Check if multiple 2D actor observations are provided + if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()): + assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), ( + "The number of CNN configurations must match the number of 2D actor observations." ) - elif len(self.actor_obs_group_2d) > 1: + elif len(self.actor_obs_groups_2d) > 1: print( - "Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups." + "Only one CNN configuration for multiple 2D actor observations given, using the same configuration " + "for all groups." ) - actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d))) + actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d))) else: - actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config])) + actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg])) + # Create CNNs for each 2D actor observation self.actor_cnns = nn.ModuleDict() encoding_dims = [] - for idx, obs_group in enumerate(self.actor_obs_group_2d): - self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_config[obs_group]) + for idx, obs_group in enumerate(self.actor_obs_groups_2d): + self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_cfg[obs_group]) print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") - # compute the encoding dimension (cpu necessary as model not moved to device yet) + # Compute the encoding dimension (cpu necessary as model not moved to device yet) encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) - encoding_dim = sum(encoding_dims) else: self.actor_cnns = None encoding_dim = 0 - # actor mlp - self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation) + # Actor MLP + self.state_dependent_std = state_dependent_std + if self.state_dependent_std: + self.actor = MLP(num_actor_obs + encoding_dim, [2, num_actions], actor_hidden_dims, activation) + else: + self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation) + print(f"Actor MLP: {self.actor}") - # actor observation normalization (only for 1D actor observations) + # Actor observation normalization (only for 1D actor observations) self.actor_obs_normalization = actor_obs_normalization if actor_obs_normalization: self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs) else: self.actor_obs_normalizer = torch.nn.Identity() - print(f"Actor MLP: {self.actor}") - # critic cnn - if self.critic_obs_group_2d: - assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations." + # Critic CNN + if self.critic_obs_groups_2d: + assert critic_cnn_cfg is not None, " A critic CNN configuration is required for 2D critic observations." 
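+            # e.g. critic_cnn_cfg = {"depth": {"out_channels": [16, 32], "kernel_size": (3, 3)}}
+            # (an illustrative mapping from a 2D observation group to CNN keyword arguments)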
# check if multiple 2D critic observations are provided - if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()): - assert len(critic_cnn_config) == len(self.critic_obs_group_2d), ( - "Number of CNN configs must match number of 2D critic observations." + if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()): + assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), ( + "The number of CNN configurations must match the number of 2D critic observations." ) - elif len(self.critic_obs_group_2d) > 1: + elif len(self.critic_obs_groups_2d) > 1: print( - "Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups." - ) - critic_cnn_config = dict( - zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)) + "Only one CNN configuration for multiple 2D critic observations given, using the same configuration" + " for all groups." ) + critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d))) else: - critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config])) + critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg])) + # Create CNNs for each 2D critic observation self.critic_cnns = nn.ModuleDict() encoding_dims = [] - for idx, obs_group in enumerate(self.critic_obs_group_2d): - self.critic_cnns[obs_group] = CNN( - num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group] - ) + for idx, obs_group in enumerate(self.critic_obs_groups_2d): + self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_cfg[obs_group]) print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") - # compute the encoding dimension (cpu necessary as model not moved to device yet) + # Compute the encoding dimension (cpu necessary as model not moved to device yet) encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) - encoding_dim = sum(encoding_dims) else: self.critic_cnns = None encoding_dim = 0 - # critic mlp + # Critic MLP self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation) + print(f"Critic MLP: {self.critic}") - # critic observation normalization (only for 1D critic observations) + # Critic observation normalization (only for 1D critic observations) self.critic_obs_normalization = critic_obs_normalization if critic_obs_normalization: self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs) else: self.critic_obs_normalizer = torch.nn.Identity() - print(f"Critic MLP: {self.critic}") # Action noise self.noise_std_type = noise_std_type - if self.noise_std_type == "scalar": - self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) - elif self.noise_std_type == "log": - self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) + if self.state_dependent_std: + torch.nn.init.zeros_(self.actor[-2].weight[num_actions:]) + if self.noise_std_type == "scalar": + torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std) + elif self.noise_std_type == "log": + torch.nn.init.constant_( + self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)) + ) + else: + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") else: - raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
+            if self.noise_std_type == "scalar":
+                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
+            elif self.noise_std_type == "log":
+                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
+            else:
+                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
+
+        # Action distribution
+        # Note: Populated in update_distribution
+        self.distribution = None
-        # Action distribution (populated in update_distribution)
-        self.distribution: Normal = None
 
-        # disable args validation for speedup
+        # Disable args validation for speedup
         Normal.set_default_validate_args(False)
 
-    def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
+    def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
         if self.actor_cnns is not None:
-            # encode the 2D actor observations
-            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
+            # Encode the 2D actor observations
+            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
-            # update mlp obs
+            # Concatenate to the MLP observations
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
 
-        super().update_distribution(mlp_obs)
+        super()._update_distribution(mlp_obs)
 
-    def act(self, obs, **kwargs):
+    def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
         mlp_obs, cnn_obs = self.get_actor_obs(obs)
         mlp_obs = self.actor_obs_normalizer(mlp_obs)
-        self.update_distribution(mlp_obs, cnn_obs)
+        self._update_distribution(mlp_obs, cnn_obs)
         return self.distribution.sample()
 
-    def act_inference(self, obs):
+    def act_inference(self, obs: TensorDict) -> torch.Tensor:
         mlp_obs, cnn_obs = self.get_actor_obs(obs)
         mlp_obs = self.actor_obs_normalizer(mlp_obs)
 
         if self.actor_cnns is not None:
-            # encode the 2D actor observations
-            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_group_2d]
+            # Encode the 2D actor observations
+            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
-            # update mlp obs
+            # Concatenate to the MLP observations
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
 
-        return self.actor(mlp_obs)
+        if self.state_dependent_std:
+            return self.actor(obs)[..., 0, :]
+        else:
+            return self.actor(mlp_obs)
 
-    def evaluate(self, obs, **kwargs):
+    def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
         mlp_obs, cnn_obs = self.get_critic_obs(obs)
         mlp_obs = self.critic_obs_normalizer(mlp_obs)
 
         if self.critic_cnns is not None:
-            # encode the 2D critic observations
-            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_group_2d]
+            # Encode the 2D critic observations
+            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d]
             cnn_enc = torch.cat(cnn_enc_list, dim=-1)
-            # update mlp obs
+            # Concatenate to the MLP observations
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
 
         return self.critic(mlp_obs)
 
-    def get_actor_obs(self, obs):
+    def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d]
         obs_dict_2d = {}
-        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_group_1d]
-        for obs_group in self.actor_obs_group_2d:
+        for obs_group in self.actor_obs_groups_2d:
             obs_dict_2d[obs_group] = obs[obs_group]
 
         return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
 
-    def get_critic_obs(self, obs):
+    def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d]
         obs_dict_2d = {}
-        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_group_1d]
-        for obs_group in self.critic_obs_group_2d:
+        for obs_group in self.critic_obs_groups_2d:
             obs_dict_2d[obs_group] = obs[obs_group]
 
         return torch.cat(obs_list_1d, dim=-1), obs_dict_2d
 
-    def update_normalization(self, obs) -> None:
+    def update_normalization(self, obs: TensorDict) -> None:
         if self.actor_obs_normalization:
             actor_obs, _ = self.get_actor_obs(obs)
             self.actor_obs_normalizer.update(actor_obs)
diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py
index 509b6821..0c3805be 100644
--- a/rsl_rl/modules/actor_critic_recurrent.py
+++ b/rsl_rl/modules/actor_critic_recurrent.py
@@ -61,9 +61,8 @@ def __init__(
             assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations."
             num_critic_obs += obs[obs_group].shape[-1]
 
-        self.state_dependent_std = state_dependent_std
-
         # Actor
+        self.state_dependent_std = state_dependent_std
         self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
         if self.state_dependent_std:
             self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
@@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None:
     def forward(self) -> NoReturn:
         raise NotImplementedError
 
-    def _update_distribution(self, obs: TensorDict) -> None:
+    def _update_distribution(self, obs: torch.Tensor) -> None:
         if self.state_dependent_std:
             # Compute mean and standard deviation
             mean_and_std = self.actor(obs)

From 3132a7ea326ef7d854468cb2d876efdbc81573a2 Mon Sep 17 00:00:00 2001
From: ClemensSchwarke
Date: Wed, 29 Oct 2025 15:05:10 +0100
Subject: [PATCH 08/10] extend CNN to more configuration options and better
 exportability

---
 rsl_rl/modules/actor_critic_perceptive.py |  68 +++++---
 rsl_rl/networks/cnn.py                    | 195 +++++++++++++++-------
 rsl_rl/networks/mlp.py                    |  20 +--
 rsl_rl/utils/__init__.py                  |   2 +
 rsl_rl/utils/utils.py                     |  15 +-
 5 files changed, 199 insertions(+), 101 deletions(-)

diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_perceptive.py
index 7d693223..46533860 100644
--- a/rsl_rl/modules/actor_critic_perceptive.py
+++ b/rsl_rl/modules/actor_critic_perceptive.py
@@ -24,8 +24,8 @@ def __init__(
         num_actions: int,
         actor_obs_normalization: bool = False,
         critic_obs_normalization: bool = False,
-        actor_hidden_dims: list[int] = [256, 256, 256],
-        critic_hidden_dims: list[int] = [256, 256, 256],
+        actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
+        critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256],
         actor_cnn_cfg: dict[str, dict] | dict | None = None,
         critic_cnn_cfg: dict[str, dict] | dict | None = None,
         activation: str = "elu",
@@ -43,30 +43,34 @@ def __init__(
 
         # Get the observation dimensions
         self.obs_groups = obs_groups
-        num_actor_obs = 0
-        num_actor_in_channels = []
+        num_actor_obs_1d = 0
         self.actor_obs_groups_1d = []
+        actor_in_dims_2d = []
+        actor_in_channels_2d = []
         self.actor_obs_groups_2d = []
         for obs_group in obs_groups["policy"]:
             if len(obs[obs_group].shape) == 4:  # B, C, H, W
                 self.actor_obs_groups_2d.append(obs_group)
-                num_actor_in_channels.append(obs[obs_group].shape[1])
+                actor_in_dims_2d.append(obs[obs_group].shape[2:4])
+                actor_in_channels_2d.append(obs[obs_group].shape[1])
             elif len(obs[obs_group].shape) == 2:  # B, C
                 self.actor_obs_groups_1d.append(obs_group)
-                num_actor_obs += obs[obs_group].shape[-1]
+                num_actor_obs_1d += obs[obs_group].shape[-1]
             else:
                 raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
 
-        num_critic_obs = 0
-        num_critic_in_channels = []
+        num_critic_obs_1d = 0
         self.critic_obs_groups_1d = []
+        critic_in_dims_2d = []
+        critic_in_channels_2d = []
         self.critic_obs_groups_2d = []
         for obs_group in obs_groups["critic"]:
             if len(obs[obs_group].shape) == 4:  # B, C, H, W
                 self.critic_obs_groups_2d.append(obs_group)
-                num_critic_in_channels.append(obs[obs_group].shape[1])
+                critic_in_dims_2d.append(obs[obs_group].shape[2:4])
+                critic_in_channels_2d.append(obs[obs_group].shape[1])
             elif len(obs[obs_group].shape) == 2:  # B, C
                 self.critic_obs_groups_1d.append(obs_group)
-                num_critic_obs += obs[obs_group].shape[-1]
+                num_critic_obs_1d += obs[obs_group].shape[-1]
             else:
                 raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
@@ -90,14 +94,19 @@ def __init__(
 
         # Create CNNs for each 2D actor observation
         self.actor_cnns = nn.ModuleDict()
-        encoding_dims = []
+        encoding_dim = 0
         for idx, obs_group in enumerate(self.actor_obs_groups_2d):
-            self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_cfg[obs_group])
+            self.actor_cnns[obs_group] = CNN(
+                input_dim=actor_in_dims_2d[idx],
+                input_channels=actor_in_channels_2d[idx],
+                **actor_cnn_cfg[obs_group],
+            )
             print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")
-
-            # Compute the encoding dimension (cpu necessary as model not moved to device yet)
-            encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
-        encoding_dim = sum(encoding_dims)
+            # Get the output dimension of the CNN
+            if self.actor_cnns[obs_group].output_channels is None:
+                encoding_dim += int(self.actor_cnns[obs_group].output_dim)  # type: ignore
+            else:
+                raise ValueError("The output of the actor CNN must be flattened before passing it to the MLP.")
         else:
             self.actor_cnns = None
             encoding_dim = 0
@@ -105,15 +114,15 @@ def __init__(
         # Actor MLP
         self.state_dependent_std = state_dependent_std
         if self.state_dependent_std:
-            self.actor = MLP(num_actor_obs + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
+            self.actor = MLP(num_actor_obs_1d + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
         else:
-            self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
+            self.actor = MLP(num_actor_obs_1d + encoding_dim, num_actions, actor_hidden_dims, activation)
         print(f"Actor MLP: {self.actor}")
 
         # Actor observation normalization (only for 1D actor observations)
         self.actor_obs_normalization = actor_obs_normalization
         if actor_obs_normalization:
-            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
+            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs_1d)
         else:
             self.actor_obs_normalizer = torch.nn.Identity()
@@ -137,26 +146,31 @@ def __init__(
 
         # Create CNNs for each 2D critic observation
         self.critic_cnns = nn.ModuleDict()
-        encoding_dims = []
+        encoding_dim = 0
         for idx, obs_group in enumerate(self.critic_obs_groups_2d):
-            self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_cfg[obs_group])
+            self.critic_cnns[obs_group] = CNN(
+                input_dim=critic_in_dims_2d[idx],
+                input_channels=critic_in_channels_2d[idx],
+                **critic_cnn_cfg[obs_group],
+            )
             print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")
-
-            # Compute the encoding dimension (cpu necessary as model not moved to device yet)
-            encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
-        encoding_dim = sum(encoding_dims)
+            # Get the output dimension of the CNN
+            if self.critic_cnns[obs_group].output_channels is None:
+                encoding_dim += int(self.critic_cnns[obs_group].output_dim)  # type: ignore
+            else:
+                raise ValueError("The output of the critic CNN must be flattened before passing it to the MLP.")
         else:
             self.critic_cnns = None
             encoding_dim = 0
 
         # Critic MLP
-        self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)
+        self.critic = MLP(num_critic_obs_1d + encoding_dim, 1, critic_hidden_dims, activation)
         print(f"Critic MLP: {self.critic}")
 
         # Critic observation normalization (only for 1D critic observations)
         self.critic_obs_normalization = critic_obs_normalization
         if critic_obs_normalization:
-            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
+            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs_1d)
         else:
             self.critic_obs_normalizer = torch.nn.Identity()
diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py
index feffb255..bac6274b 100644
--- a/rsl_rl/networks/cnn.py
+++ b/rsl_rl/networks/cnn.py
@@ -5,105 +5,188 @@
 
 from __future__ import annotations
 
+import math
 import torch
 from torch import nn as nn
 
-from rsl_rl.utils import resolve_nn_activation
+from rsl_rl.utils import get_param, resolve_nn_activation
 
 
 class CNN(nn.Sequential):
     """Convolutional Neural Network (CNN).
 
-    The CNN network is a sequence of convolutional layers, optional batch normalization, activation functions, and
-    optional max pooling. The final output can be flattened or pooled depending on the configuration.
+    The CNN network is a sequence of convolutional layers, optional normalization layers, activation functions, and
+    optional pooling. The final output can be flattened.
     """
 
     def __init__(
         self,
-        in_channels: int,
-        activation: str,
-        out_channels: list[int],
-        kernel_size: list[tuple[int, int]] | tuple[int, int],
-        stride: list[int] | int = 1,
+        input_dim: tuple[int, int],
+        input_channels: int,
+        output_channels: tuple[int] | list[int],
+        kernel_size: int | tuple[int] | list[int],
+        stride: int | tuple[int] | list[int] = 1,
+        dilation: int | tuple[int] | list[int] = 1,
+        padding: str = "none",
+        norm: str | tuple[str] | list[str] = "none",
+        activation: str = "elu",
+        max_pool: bool | tuple[bool] | list[bool] = False,
+        global_pool: str = "none",
         flatten: bool = True,
-        avg_pool: tuple[int, int] | None = None,
-        batchnorm: bool | list[bool] = False,
-        max_pool: bool | list[bool] = False,
     ) -> None:
         """Initialize the CNN.
 
         Args:
-            in_channels: Number of input channels.
-            activation: Activation function to use.
-            out_channels: List of output channels for each convolutional layer.
+            input_dim: Height and width of the input.
+            input_channels: Number of input channels.
+            output_channels: List of output channels for each convolutional layer.
             kernel_size: List of kernel sizes for each convolutional layer or a single kernel size for all layers.
             stride: List of strides for each convolutional layer or a single stride for all layers.
+            dilation: List of dilations for each convolutional layer or a single dilation for all layers.
+            padding: Padding type to use. Either 'none', 'zeros', 'reflect', 'replicate', or 'circular'.
+            norm: List of normalization types for each convolutional layer or a single type for all layers. Either
+                'none', 'batch', or 'layer'.
+            activation: Activation function to use.
+            max_pool: List of booleans indicating whether to apply max pooling after each convolutional layer or a
+                single boolean for all layers.
+            global_pool: Global pooling type to apply at the end. Either 'none', 'max', or 'avg'.
             flatten: Whether to flatten the output tensor.
-            avg_pool: If specified, applies an adaptive average pooling to the given output size after the convolutions.
-            batchnorm: Whether to apply batch normalization after each convolutional layer.
-            max_pool: Whether to apply max pooling after each convolutional layer.
-
-        .. note::
-            Do not save config to allow for the model to be jit compiled.
         """
         super().__init__()
 
-        # If parameters are not lists, convert them to lists
-        if isinstance(batchnorm, bool):
-            batchnorm = [batchnorm] * len(out_channels)
-        if isinstance(max_pool, bool):
-            max_pool = [max_pool] * len(out_channels)
-        if isinstance(kernel_size, tuple):
-            kernel_size = [kernel_size] * len(out_channels)
-        if isinstance(stride, int):
-            stride = [stride] * len(out_channels)
-
         # Resolve activation function
         activation_function = resolve_nn_activation(activation)
 
         # Create layers sequentially
         layers = []
-        for idx in range(len(out_channels)):
-            in_channels = in_channels if idx == 0 else out_channels[idx - 1]
+        last_channels = input_channels
+        last_dim = input_dim
+        for idx in range(len(output_channels)):
+            # Get parameters for the current layer
+            k = get_param(kernel_size, idx)
+            s = get_param(stride, idx)
+            d = get_param(dilation, idx)
+            p = (
+                _compute_padding(last_dim, k, s, d)
+                if padding in ["zeros", "reflect", "replicate", "circular"]
+                else (0, 0)
+            )
+
+            # Append convolutional layer
             layers.append(
                 nn.Conv2d(
-                    in_channels=in_channels,
-                    out_channels=out_channels[idx],
-                    kernel_size=kernel_size[idx],
-                    stride=stride[idx],
+                    in_channels=last_channels,
+                    out_channels=output_channels[idx],
+                    kernel_size=k,
+                    stride=s,
+                    padding=p,
+                    dilation=d,
+                    padding_mode=padding if padding in ["zeros", "reflect", "replicate", "circular"] else "zeros",
                 )
             )
-            if batchnorm[idx]:
-                layers.append(nn.BatchNorm2d(num_features=out_channels[idx]))
+
+            # Append normalization layer if specified
+            n = get_param(norm, idx)
+            if n == "none":
+                pass
+            elif n == "batch":
+                layers.append(nn.BatchNorm2d(output_channels[idx]))
+            elif n == "layer":
+                norm_input_dim = _compute_output_dim(last_dim, k, s, d, p)
+                layers.append(nn.LayerNorm([output_channels[idx], norm_input_dim[0], norm_input_dim[1]]))
+            else:
+                raise ValueError(
+                    f"Unsupported normalization type: {n}. Supported types are 'none', 'batch', and 'layer'."
+                )
+
+            # Append activation function
             layers.append(activation_function)
-            if max_pool[idx]:
+
+            # Apply max pooling if specified
+            if get_param(max_pool, idx):
                 layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
 
+            # Update last channels and dimensions
+            last_channels = output_channels[idx]
+            last_dim = _compute_output_dim(last_dim, k, s, d, p, is_max_pool=get_param(max_pool, idx))
+
+        # Apply global pooling if specified
+        if global_pool == "none":
+            pass
+        elif global_pool == "max":
+            layers.append(nn.AdaptiveMaxPool2d((1, 1)))
+            last_dim = (1, 1)
+        elif global_pool == "avg":
+            layers.append(nn.AdaptiveAvgPool2d((1, 1)))
+            last_dim = (1, 1)
+        else:
+            raise ValueError(
+                f"Unsupported global pooling type: {global_pool}. Supported types are 'none', 'max', and 'avg'."
+            )
+
+        # Apply flattening if specified
+        if flatten:
+            layers.append(nn.Flatten(start_dim=1))
+
+        # Store final output dimension
+        self._output_channels = last_channels if not flatten else None
+        self._output_dim = last_dim if not flatten else last_channels * last_dim[0] * last_dim[1]
+
         # Register the layers
         for idx, layer in enumerate(layers):
             self.add_module(f"{idx}", layer)
 
-        # Add avgpool if specified
-        if avg_pool is not None:
-            self.avgpool = nn.AdaptiveAvgPool2d(avg_pool)
-        else:
-            self.avgpool = None
+    @property
+    def output_channels(self) -> int | None:
+        """Get the number of output channels or None if output is flattened."""
+        return self._output_channels
 
-        # Save flatten flag for forward function
-        self.flatten = flatten
+    @property
+    def output_dim(self) -> tuple[int, int] | int:
+        """Get the output height and width or total output dimension if output is flattened."""
+        return self._output_dim
+
+    def init_weights(self) -> None:
+        """Initialize the weights of the CNN with Kaiming normal initialization."""
+        for idx, module in enumerate(self):
+            if isinstance(module, nn.Conv2d):
+                torch.nn.init.kaiming_normal_(module.weight)
+                torch.nn.init.zeros_(module.bias)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the CNN."""
         for layer in self:
             x = layer(x)
-        if self.flatten:
-            x = x.flatten(start_dim=1)
-        elif self.avgpool is not None:
-            x = self.avgpool(x)
-            x = x.flatten(start_dim=1)
         return x
 
-    def init_weights(self) -> None:
-        """Initialize the weights of the CNN with Xavier initialization."""
-        for idx, module in enumerate(self):
-            if isinstance(module, nn.Conv2d):
-                nn.init.xavier_uniform_(module.weight)
+
+def _compute_padding(input_hw: tuple[int, int], kernel: int, stride: int, dilation: int) -> tuple[int, int]:
+    """Compute the optimal padding for the current layer.
+
+    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+    """
+    h = math.ceil((stride * math.floor(input_hw[0] / stride) - input_hw[0] - stride + dilation * (kernel - 1) + 1) / 2)
+    w = math.ceil((stride * math.floor(input_hw[1] / stride) - input_hw[1] - stride + dilation * (kernel - 1) + 1) / 2)
+    return (h, w)
+
+
+def _compute_output_dim(
+    input_hw: tuple[int, int],
+    kernel: int,
+    stride: int,
+    dilation: int,
+    padding: tuple[int, int],
+    is_max_pool: bool = False,
+) -> tuple[int, int]:
+    """Compute the output height and width of the current layer.
+
+    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+    """
+    h = math.floor((input_hw[0] + 2 * padding[0] - dilation * (kernel - 1) - 1) / stride + 1)
+    w = math.floor((input_hw[1] + 2 * padding[1] - dilation * (kernel - 1) - 1) / stride + 1)
 
+    if is_max_pool:
+        h = math.ceil(h / 2)
+        w = math.ceil(w / 2)
+
+    return (h, w)
diff --git a/rsl_rl/networks/mlp.py b/rsl_rl/networks/mlp.py
index f01a7577..25f26804 100644
--- a/rsl_rl/networks/mlp.py
+++ b/rsl_rl/networks/mlp.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from functools import reduce
 
-from rsl_rl.utils import resolve_nn_activation
+from rsl_rl.utils import get_param, resolve_nn_activation
 
 
 class MLP(nn.Sequential):
@@ -82,27 +82,13 @@ def init_weights(self, scales: float | tuple[float]) -> None:
         Args:
             scales: Scale factor for the weights.
         """
-
-        def get_scale(idx: int) -> float:
-            """Get the scale factor for the weights of the MLP.
-
-            Args:
-                idx: Index of the layer.
- """ - return scales[idx] if isinstance(scales, (list, tuple)) else scales - - # Initialize the weights for idx, module in enumerate(self): if isinstance(module, nn.Linear): - nn.init.orthogonal_(module.weight, gain=get_scale(idx)) + nn.init.orthogonal_(module.weight, gain=get_param(scales, idx)) nn.init.zeros_(module.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass of the MLP. - - Args: - x: Input tensor. - """ + """Forward pass of the MLP.""" for layer in self: x = layer(x) return x diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index a11074e0..a44bc678 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -6,6 +6,7 @@ """Helper functions.""" from .utils import ( + get_param, resolve_nn_activation, resolve_obs_groups, resolve_optimizer, @@ -16,6 +17,7 @@ ) __all__ = [ + "get_param", "resolve_nn_activation", "resolve_obs_groups", "resolve_optimizer", diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index 7a044e83..c1638d29 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -12,7 +12,20 @@ import torch import warnings from tensordict import TensorDict -from typing import Callable +from typing import Any, Callable + + +def get_param(param: Any, idx: int) -> Any: + """Get a parameter for the given index. + + Args: + param: Parameter or list/tuple of parameters. + idx: Index to get the parameter for. + """ + if isinstance(param, (tuple, list)): + return param[idx] + else: + return param def resolve_nn_activation(act_name: str) -> torch.nn.Module: From cd53d291ef9966633868ba0fad6865a4395d83fe Mon Sep 17 00:00:00 2001 From: ClemensSchwarke Date: Fri, 14 Nov 2025 14:29:56 +0100 Subject: [PATCH 09/10] fixes --- rsl_rl/modules/actor_critic_perceptive.py | 55 ++++++++++------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_perceptive.py index 46533860..27648da5 100644 --- a/rsl_rl/modules/actor_critic_perceptive.py +++ b/rsl_rl/modules/actor_critic_perceptive.py @@ -39,7 +39,7 @@ def __init__( "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs]) ) - nn.Module.__init__(self) + super(ActorCritic, self).__init__() # Get the observation dimensions self.obs_groups = obs_groups @@ -74,23 +74,22 @@ def __init__( else: raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") + # Assert that there are 2D observations + assert self.actor_obs_groups_2d or self.critic_obs_groups_2d, ( + "No 2D observations are provided. If this is intentional, use the ActorCritic module instead." + ) + # Actor CNN if self.actor_obs_groups_2d: + # Resolve the actor CNN configuration assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations." - - # Check if multiple 2D actor observations are provided - if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()): - assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), ( - "The number of CNN configurations must match the number of 2D actor observations." - ) - elif len(self.actor_obs_groups_2d) > 1: - print( - "Only one CNN configuration for multiple 2D actor observations given, using the same configuration " - "for all groups." 
-                )
-                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d)))
-            else:
-                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg]))
+            # If a single configuration dictionary is provided, create a dictionary for each 2D observation group
+            if not all(isinstance(v, dict) for v in actor_cnn_cfg.values()):
+                actor_cnn_cfg = {group: actor_cnn_cfg for group in self.actor_obs_groups_2d}
+            # Check that the number of configs matches the number of observation groups
+            assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
+                "The number of CNN configurations must match the number of 2D actor observations."
+            )
 
             # Create CNNs for each 2D actor observation
             self.actor_cnns = nn.ModuleDict()
@@ -128,21 +127,15 @@ def __init__(
 
         # Critic CNN
         if self.critic_obs_groups_2d:
-            assert critic_cnn_cfg is not None, " A critic CNN configuration is required for 2D critic observations."
-
-            # check if multiple 2D critic observations are provided
-            if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()):
-                assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
-                    "The number of CNN configurations must match the number of 2D critic observations."
-                )
-            elif len(self.critic_obs_groups_2d) > 1:
-                print(
-                    "Only one CNN configuration for multiple 2D critic observations given, using the same configuration"
-                    " for all groups."
-                )
-                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d)))
-            else:
-                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg]))
+            # Resolve the critic CNN configuration
+            assert critic_cnn_cfg is not None, "A critic CNN configuration is required for 2D critic observations."
+            # If a single configuration dictionary is provided, create a dictionary for each 2D observation group
+            if not all(isinstance(v, dict) for v in critic_cnn_cfg.values()):
+                critic_cnn_cfg = {group: critic_cnn_cfg for group in self.critic_obs_groups_2d}
+            # Check that the number of configs matches the number of observation groups
+            assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
+                "The number of CNN configurations must match the number of 2D critic observations."
+            )
 
             # Create CNNs for each 2D critic observation
             self.critic_cnns = nn.ModuleDict()
@@ -229,7 +222,7 @@ def act_inference(self, obs: TensorDict) -> torch.Tensor:
             mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)
 
         if self.state_dependent_std:
-            return self.actor(obs)[..., 0, :]
+            return self.actor(mlp_obs)[..., 0, :]
         else:
             return self.actor(mlp_obs)

From b0e24aee94280fcc6978dac46a3ab604ec259842 Mon Sep 17 00:00:00 2001
From: ClemensSchwarke
Date: Fri, 14 Nov 2025 15:05:39 +0100
Subject: [PATCH 10/10] rename perceptive actor critic

---
 rsl_rl/algorithms/ppo.py                                 | 6 +++---
 rsl_rl/modules/__init__.py                               | 4 ++--
 .../{actor_critic_perceptive.py => actor_critic_cnn.py}  | 4 ++--
 rsl_rl/runners/on_policy_runner.py                       | 4 ++--
 4 files changed, 9 insertions(+), 9 deletions(-)
 rename rsl_rl/modules/{actor_critic_perceptive.py => actor_critic_cnn.py} (98%)

diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index 410d52ea..b9ae2737 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -11,7 +11,7 @@
 from itertools import chain
 from tensordict import TensorDict
 
-from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent
+from rsl_rl.modules import ActorCritic, ActorCriticCNN, ActorCriticRecurrent
 from rsl_rl.modules.rnd import RandomNetworkDistillation
 from rsl_rl.storage import RolloutStorage
 from rsl_rl.utils import string_to_callable
@@ -20,12 +20,12 @@
 class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
 
-    policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive
+    policy: ActorCritic | ActorCriticRecurrent | ActorCriticCNN
     """The actor critic module."""
 
     def __init__(
         self,
-        policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive,
+        policy: ActorCritic | ActorCriticRecurrent | ActorCriticCNN,
         num_learning_epochs: int = 5,
         num_mini_batches: int = 4,
         clip_param: float = 0.2,
diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index 7803aa08..3c867598 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -6,7 +6,7 @@
 """Definitions for neural-network components for RL-agents."""
 
 from .actor_critic import ActorCritic
-from .actor_critic_perceptive import ActorCriticPerceptive
+from .actor_critic_cnn import ActorCriticCNN
 from .actor_critic_recurrent import ActorCriticRecurrent
 from .rnd import RandomNetworkDistillation, resolve_rnd_config
 from .student_teacher import StudentTeacher
@@ -15,7 +15,7 @@
 
 __all__ = [
     "ActorCritic",
-    "ActorCriticPerceptive",
+    "ActorCriticCNN",
     "ActorCriticRecurrent",
     "RandomNetworkDistillation",
     "StudentTeacher",
diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_cnn.py
similarity index 98%
rename from rsl_rl/modules/actor_critic_perceptive.py
rename to rsl_rl/modules/actor_critic_cnn.py
index 27648da5..b9ca2ab9 100644
--- a/rsl_rl/modules/actor_critic_perceptive.py
+++ b/rsl_rl/modules/actor_critic_cnn.py
@@ -16,7 +16,7 @@
 from .actor_critic import ActorCritic
 
 
-class ActorCriticPerceptive(ActorCritic):
+class ActorCriticCNN(ActorCritic):
     def __init__(
         self,
         obs: TensorDict,
@@ -36,7 +36,7 @@ def __init__(
     ) -> None:
         if kwargs:
             print(
-                "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
+                "ActorCriticCNN.__init__ got unexpected arguments, which will be ignored: "
                 + str([key for key in kwargs])
             )
         super(ActorCritic, self).__init__()
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index 2b0d7664..b5817c7a 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -18,7 +18,7 @@
 from rsl_rl.env import VecEnv
 from rsl_rl.modules import (
     ActorCritic,
-    ActorCriticPerceptive,
+    ActorCriticCNN,
     ActorCriticRecurrent,
     resolve_rnd_config,
     resolve_symmetry_config,
@@ -420,7 +420,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
 
         # Initialize the policy
         actor_critic_class = eval(self.policy_cfg.pop("class_name"))
-        actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive = actor_critic_class(
+        actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticCNN = actor_critic_class(
            obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
         ).to(self.device)
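
As a quick sanity check of the CNN interface introduced in PATCH 08, a minimal sketch follows; the input size and layer configuration are illustrative assumptions, not values taken from the patches:

    import torch

    from rsl_rl.networks import CNN

    # Two conv layers on a single-channel 32x32 input; each layer is followed by
    # ELU and, since max_pool=True, a MaxPool2d(kernel_size=3, stride=2, padding=1).
    cnn = CNN(
        input_dim=(32, 32),
        input_channels=1,
        output_channels=[16, 32],
        kernel_size=3,
        max_pool=True,
    )

    # With padding="none": 32 -> 30 (conv) -> 15 (pool) -> 13 (conv) -> 7 (pool),
    # so the flattened encoding is 32 * 7 * 7 = 1568 and output_channels is None.
    print(cnn.output_dim)       # 1568
    print(cnn.output_channels)  # None

    x = torch.zeros(4, 1, 32, 32)  # (B, C, H, W)
    print(cnn(x).shape)            # torch.Size([4, 1568])

Because flatten=True, output_dim is exactly the encoding size the perceptive actor-critic adds to its MLP input, which is why the module raises if the CNN output is not flattened.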
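Similarly, a sketch of how the renamed ActorCriticCNN consumes mixed 1D/2D observation groups; the group names, shapes, and action dimension below are made up for illustration:

    import torch
    from tensordict import TensorDict

    from rsl_rl.modules import ActorCriticCNN

    obs = TensorDict(
        {
            "state": torch.zeros(4, 48),         # 1D group (B, C) -> MLP branch
            "depth": torch.zeros(4, 1, 32, 32),  # 2D group (B, C, H, W) -> CNN branch
        },
        batch_size=[4],
    )
    obs_groups = {"policy": ["state", "depth"], "critic": ["state", "depth"]}

    # A single flat config is broadcast to every 2D group (see the resolution
    # logic in PATCH 09); each flattened CNN encoding is concatenated to the
    # 1D observations before the MLP.
    cnn_cfg = {"output_channels": [16, 32], "kernel_size": 3, "max_pool": True}
    policy = ActorCriticCNN(obs, obs_groups, 12, actor_cnn_cfg=cnn_cfg, critic_cnn_cfg=cnn_cfg)

    actions = policy.act(obs)      # (4, 12), sampled from the action distribution
    values = policy.evaluate(obs)  # (4, 1)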