 from tensordict.utils import NestedKey
 from torch import distributions as d
 
+from torchrl._utils import _standardize
 from torchrl.objectives.common import LossModule
 
 from torchrl.objectives.utils import (
     TDLambdaEstimator,
     VTrace,
 )
+import warnings
 
 
 class PPOLoss(LossModule):
@@ -87,6 +89,9 @@ class PPOLoss(LossModule):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
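For orientation, here is a minimal usage sketch of the new argument (the `policy` and `value_net` modules and the advantage shape are placeholders, not part of this commit): in a multi-agent layout where the advantage carries an extra agent dimension, excluding that dimension keeps per-agent normalization statistics.

```python
from torchrl.objectives import PPOLoss

# Sketch only: `policy` and `value_net` are assumed to be pre-built TensorDictModules.
loss_module = PPOLoss(
    actor_network=policy,
    critic_network=value_net,
    normalize_advantage=True,
    # Advantage assumed shaped [*batch, n_agents, 1]; excluding dim -2 keeps
    # the per-agent mean/std separate when standardizing.
    normalize_advantage_exclude_dims=(-2,),
)
```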
@@ -311,6 +316,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         advantage_key: str = None,
@@ -381,6 +387,8 @@ def __init__(
             self.critic_coef = None
         self.loss_critic_type = loss_critic_type
         self.normalize_advantage = normalize_advantage
+        self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
+
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._set_deprecated_ctor_keys(
@@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is highly suggested."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
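The normalization itself now goes through `_standardize` from `torchrl._utils`. As a hedged sketch of the intended behavior (an illustration, not the library's implementation): statistics are reduced over every dimension except the excluded ones, so the excluded dimensions keep independent mean/std.

```python
import torch

def standardize_excluding(adv: torch.Tensor, exclude_dims=()) -> torch.Tensor:
    """Illustrative only: normalize over all dims except `exclude_dims`."""
    exclude = {d % adv.ndim for d in exclude_dims}  # resolve negative dims
    reduce_dims = tuple(d for d in range(adv.ndim) if d not in exclude)
    if not reduce_dims:  # everything excluded: nothing to reduce over
        return adv
    loc = adv.mean(dim=reduce_dims, keepdim=True)
    scale = adv.std(dim=reduce_dims, keepdim=True).clamp_min(1e-6)
    return (adv - loc) / scale
```

With `exclude_dims=(-2,)` and an advantage shaped `[B, T, n_agents, 1]`, each agent is standardized with its own statistics instead of a single global mean/std.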
@@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -802,6 +820,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -821,6 +840,7 @@ def __init__(
             critic_coef=critic_coef,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
+            normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
             gamma=gamma,
             separate_losses=separate_losses,
             reduction=reduction,
@@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is highly suggested."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
         # ESS for logging
@@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1048,6 +1078,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -1063,6 +1094,7 @@ def __init__(
             critic_coef=critic_coef,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
+            normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
             gamma=gamma,
             separate_losses=separate_losses,
             reduction=reduction,
@@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             )
         advantage = tensordict_copy.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is highly suggested."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
+
         log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
         neg_loss = log_weight.exp() * advantage
 
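To make the warning condition concrete, here is a small illustrative sketch (the shapes and key layout are assumptions, not taken from this commit): the warning fires when the advantage entry has more elements than the tensordict batch and no exclude dims are given.

```python
import torch
from tensordict import TensorDict

# Hypothetical multi-agent rollout: batch dims are [envs, time], while the
# advantage carries an extra [n_agents, 1] tail.
td = TensorDict(
    {"advantage": torch.randn(8, 64, 4, 1)},
    batch_size=[8, 64],
)
assert td["advantage"].numel() > td.batch_size.numel()  # 2048 > 512 -> warning path
# Passing normalize_advantage_exclude_dims=(-2,) to the loss keeps the agent
# dimension out of the mean/std reduction and silences the warning.
```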