 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
+from typing import Mapping

 import torch
 from tensordict import (
@@ -84,7 +85,9 @@ class PPOLoss(LossModule):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action head's entropy.
             Defaults to ``0.01``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
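A minimal usage sketch of the two accepted forms (the `actor_network`/`value_network` modules and the head names `agent0`/`agent1` are hypothetical placeholders, not part of this diff):

    from torchrl.objectives import PPOLoss

    # Scalar form: one coefficient scales the summed entropy of all heads.
    loss = PPOLoss(actor_network, value_network, entropy_coef=0.01)

    # Mapping form: one coefficient per action head.
    loss = PPOLoss(
        actor_network,
        value_network,
        entropy_coef={"agent0": 0.01, "agent1": 0.05},
    )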
@@ -330,7 +333,7 @@ def __init__(
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -408,7 +411,22 @@ def __init__(
             torch, "get_default_device", lambda: torch.device("cpu")
         )()

-        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
+        if isinstance(entropy_coef, Mapping):
+            # Store the mapping for per-head coefficients.
+            self._entropy_coef_map = {str(k): float(v) for k, v in entropy_coef.items()}
+            # Register an empty buffer for compatibility.
+            self.register_buffer("entropy_coef", torch.tensor(0.0))
+        elif isinstance(entropy_coef, (float, int, torch.Tensor)):
+            # Register the scalar entropy coefficient.
+            coef = (
+                float(entropy_coef)
+                if not torch.is_tensor(entropy_coef)
+                else float(entropy_coef.item())
+            )
+            self.register_buffer("entropy_coef", torch.tensor(coef))
+            self._entropy_coef_map = None
+        else:
+            raise TypeError("entropy_coef must be a float or a Mapping[str, float]")
         if critic_coef is not None:
             self.register_buffer(
                 "critic_coef", torch.tensor(critic_coef, device=device)
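A standalone sketch of what this branch accepts and how each input is normalized (mirrors the logic above outside the class; the function name is made up for illustration):

    import torch
    from typing import Mapping

    def normalize_entropy_coef(entropy_coef):
        # Mapping -> per-head coefficient table, no scalar buffer.
        if isinstance(entropy_coef, Mapping):
            return None, {str(k): float(v) for k, v in entropy_coef.items()}
        # Scalar (float, int, or 0-dim tensor) -> plain float, no table.
        if isinstance(entropy_coef, (float, int, torch.Tensor)):
            coef = (
                float(entropy_coef)
                if not torch.is_tensor(entropy_coef)
                else float(entropy_coef.item())
            )
            return coef, None
        raise TypeError("entropy_coef must be a float or a Mapping[str, float]")

    assert normalize_entropy_coef(0.01) == (0.01, None)
    assert normalize_entropy_coef(torch.tensor(0.5)) == (0.5, None)
    assert normalize_entropy_coef({"agent0": 1e-2}) == (None, {"agent0": 0.01})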
@@ -540,7 +558,6 @@ def _get_entropy(
         return entropy.unsqueeze(-1)

     def _get_cur_log_prob(self, tensordict):
-
         if isinstance(
             self.actor_network,
             (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
@@ -589,7 +606,6 @@ def _get_cur_log_prob(self, tensordict):
     def _log_weight(
         self, tensordict: TensorDictBase, adv_shape: torch.Size
     ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
-
         prev_log_prob = _maybe_get_or_select(
             tensordict,
             self.tensor_keys.sample_log_prob,
@@ -745,9 +761,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-                td_out.set("entropy", entropy.detach().mean())  # for logging
-                td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
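With either coefficient form, the output tensordict would then carry the following entropy-related keys (a hedged sketch; `loss` and `data` as in the construction example above):

    td_out = loss(data)
    td_out.get("entropy")       # detached scalar mean entropy, for logging
    td_out.get("loss_entropy")  # differentiable term, weighted by the coefficient(s)
    # Only present when the distribution is composite (per-head entropies):
    td_out.get("composite_entropy", None)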
@@ -814,6 +833,35 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         }
         self._value_estimator.set_keys(**tensor_keys)

+    def _weighted_loss_entropy(
+        self, entropy: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor:
+        """Compute the weighted entropy loss.
+
+        If ``self._entropy_coef_map`` is provided, apply per-head entropy coefficients.
+        Otherwise, use the scalar ``self.entropy_coef``.
+        """
+        if self._entropy_coef_map is None:
+            if is_tensor_collection(entropy):
+                entropy = _sum_td_features(entropy)
+            return -self.entropy_coef * entropy
+
+        loss_term = None  # running sum over heads
+        for head_name, entropy_head in entropy.items():
+            try:
+                coeff = self._entropy_coef_map[head_name]
+            except KeyError as exc:
+                raise KeyError(f"Missing entropy coef for head '{head_name}'") from exc
+            coeff_t = torch.as_tensor(
+                coeff, dtype=entropy_head.dtype, device=entropy_head.device
+            )
+            head_loss_term = -coeff_t * _sum_td_features(entropy_head)
+            loss_term = (
+                head_loss_term if loss_term is None else loss_term + head_loss_term
+            )  # accumulate
+
+        return loss_term
+

class ClipPPOLoss(PPOLoss):
    """Clipped PPO loss.
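The weighting itself is plain per-head scaling. A self-contained sketch with ordinary tensors in place of a TensorDict (head names and values are made up; `ent.sum(-1, keepdim=True)` stands in for `_sum_td_features`):

    import torch

    per_head_entropy = {
        "agent0": torch.full((4, 1), 0.9),
        "agent1": torch.full((4, 1), 0.2),
    }
    coef_map = {"agent0": 0.01, "agent1": 0.05}

    loss_entropy = None
    for head_name, ent in per_head_entropy.items():
        term = -coef_map[head_name] * ent.sum(-1, keepdim=True)
        loss_entropy = term if loss_entropy is None else loss_entropy + term

    # agent1's entropy bonus is weighted five times as strongly as agent0's:
    print(loss_entropy.squeeze(-1))  # shape (4,), each entry -(0.01*0.9 + 0.05*0.2)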
@@ -836,7 +884,9 @@ class ClipPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action head's entropy.
             Defaults to ``0.01``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
@@ -939,7 +989,7 @@ def __init__(
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1064,9 +1114,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-                td_out.set("entropy", entropy.detach().mean())  # for logging
-                td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
@@ -1120,7 +1173,9 @@ class KLPENPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action head's entropy.
             Defaults to ``0.01``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``.
@@ -1224,7 +1279,7 @@ def __init__(
         samples_mc_kl: int = 1,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1405,9 +1460,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-                td_out.set("entropy", entropy.detach().mean())  # for logging
-                td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)