@@ -104,6 +104,8 @@ class PPOLoss(LossModule):
104104 * **Scalar**: one value applied to the summed entropy of every action head.
105105 * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
106106 Defaults to ``0.01``.
107+
108+ See the :ref:`entropy coefficient notes <ppo_entropy_coefficients>` for detailed usage examples and troubleshooting.
107109 log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
108110 predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
109111 This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
@@ -217,7 +219,7 @@ class PPOLoss(LossModule):
217219 >>> action = spec.rand(batch)
218220 >>> data = TensorDict({"observation": torch.randn(*batch, n_obs),
219221 ... "action": action,
220- ... "sample_log_prob": torch.randn_like(action[..., 1]),
222+ ... "action_log_prob": torch.randn_like(action[..., 1]),
221223 ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
222224 ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
223225 ... ("next", "reward"): torch.randn(*batch, 1),
@@ -227,6 +229,8 @@ class PPOLoss(LossModule):
227229 TensorDict(
228230 fields={
229231 entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
232+ explained_variance: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
233+ kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
230234 loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
231235 loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
232236 loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
@@ -279,12 +283,69 @@ class PPOLoss(LossModule):
279283 ... next_observation=torch.randn(*batch, n_obs))
280284 >>> loss_objective.backward()
281285
286+ **Simple Entropy Coefficient Examples**:
287+ >>> # Scalar entropy coefficient (default behavior)
288+ >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
289+ >>>
290+ >>> # Per-head entropy coefficients (for composite action spaces)
291+ >>> entropy_coeff = {
292+ ... ("agent0", "action_log_prob"): 0.01, # Low exploration
293+ ... ("agent1", "action_log_prob"): 0.05, # High exploration
294+ ... }
295+ >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
296+
282297 .. note::
283298 There is an exception regarding compatibility with non-tensordict-based modules.
284299 If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`,
285300 this class must be used with tensordicts and cannot function as a tensordict-independent module.
286301 This is because composite action spaces inherently rely on the structured representation of data provided by
287302 tensordicts to handle their actions.
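
A minimal sketch of such a composite setup, assuming tensordict's :class:`~tensordict.nn.distributions.CompositeDistribution` API (the ``movement``/``gripper`` head names and parameter shapes are illustrative):

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.distributions import CompositeDistribution
>>> from torch import distributions as d
>>> # one parameter sub-tree per action head
>>> params = TensorDict({
...     "movement": {"loc": torch.zeros(3), "scale": torch.ones(3)},
...     "gripper": {"logits": torch.randn(2)},
... }, batch_size=[])
>>> dist = CompositeDistribution(
...     params,
...     distribution_map={"movement": d.Normal, "gripper": d.Categorical},
... )
>>> sample = dist.sample()  # a structured TensorDict, one action entry per head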
303+
304+ .. _ppo_entropy_coefficients:
305+
306+ .. note::
307+ **Entropy Bonus and Coefficient Management**
308+
309+ The entropy bonus encourages exploration by adding the negative entropy of the policy to the loss.
310+ This can be configured in two ways:
311+
312+ **Scalar Coefficient (Default)**: Use a single coefficient for all action heads:
313+ >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
314+
315+ **Per-Head Coefficients**: Use different coefficients for different action components:
316+ >>> # For a robot with movement and gripper actions
317+ >>> entropy_coeff = {
318+ ... ("agent0", "action_log_prob"): 0.01, # Movement: low exploration
319+ ... ("agent1", "action_log_prob"): 0.05, # Gripper: high exploration
320+ ... }
321+ >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
322+
323+ **Key Requirements**: When using per-head coefficients, you must provide the full nested key
324+ path to each action head's log probability (e.g., ``("agent0", "action_log_prob")``).
325+
326+ **Monitoring Entropy Loss**:
327+
328+ When using composite action spaces, the loss output includes:
329+ - `"entropy"`: Summed entropy across all action heads (for logging)
330+ - `"composite_entropy"`: Individual entropy values for each action head
331+ - `"loss_entropy"`: The weighted entropy loss term
332+
333+ Example output:
334+ >>> result = loss(data)
335+ >>> print(result["entropy"]) # Total entropy: 2.34
336+ >>> print(result["composite_entropy"]) # Per-head: {"movement": 1.2, "gripper": 1.14}
337+ >>> print(result["loss_entropy"]) # Weighted loss: -0.0234
338+
339+ **Common Issues**:
340+
341+ **KeyError: "Missing entropy coeff for head 'head_name'"**:
342+ - Ensure you provide coefficients for ALL action heads
343+ - Use full nested keys: ``("head_name", "action_log_prob")``
344+ - Check that your action space structure matches the coefficient mapping
345+
346+ **Incorrect Entropy Calculation**:
347+ - Call ``set_composite_lp_aggregate(False).set()`` before creating your policy
348+ - Verify that your action space uses :class:`~tensordict.nn.distributions.CompositeDistribution`
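
A minimal sketch of that switch (assuming it is imported from ``tensordict.nn``, as in current tensordict releases); it must run before the policy is built:

>>> from tensordict.nn import set_composite_lp_aggregate
>>> set_composite_lp_aggregate(False).set()  # one-time, process-wide switch
>>> # build the actor, critic, and loss module only after this point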
288349 """
289350
290351 @dataclass
@@ -911,27 +972,37 @@ def _weighted_loss_entropy(
911972 Otherwise, use the scalar `self.entropy_coeff`.
912973 The entries in self._entropy_coeff_map require the full nested key to the entropy head.
913974 """
975+ # Mode 1: Use scalar entropy coefficient (default behavior)
914976 if self._entropy_coeff_map is None:
977+ # If entropy is a TensorDict (composite action space), sum all entropy values
915978 if is_tensor_collection(entropy):
916979 entropy = _sum_td_features(entropy)
980+ # Apply scalar coefficient: loss = -coeff * entropy (negative for maximization)
917981 return -self.entropy_coeff * entropy
918982
919- loss_term = None # running sum over heads
920- coeff = 0
983+ # Mode 2: Use per-head entropy coefficients (for composite action spaces)
984+ loss_term = None # Initialize running sum over action heads
985+ coeff = 0 # Placeholder for coefficient value
986+ # Iterate through all entropy heads in the composite action space
921987 for head_name, entropy_head in entropy.items(
922988 include_nested=True, leaves_only=True
923989 ):
924990 try:
991+ # Look up the coefficient for this specific action head
925992 coeff = self._entropy_coeff_map[head_name]
926993 except KeyError as exc:
994+ # Provide clear error message if coefficient mapping is incomplete
927995 raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
996+ # Convert coefficient to tensor with matching dtype and device
928997 coeff_t = torch.as_tensor(
929998 coeff, dtype=entropy_head.dtype, device=entropy_head.device
930999 )
1000+ # Compute weighted loss for this head: -coeff * entropy
9311001 head_loss_term = -coeff_t * entropy_head
1002+ # Accumulate loss terms across all heads
9321003 loss_term = (
9331004 head_loss_term if loss_term is None else loss_term + head_loss_term
934- ) # accumulate
1005+ )
9351006
9361007 return loss_term
9371008
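As a sanity check on the per-head branch above, a hypothetical standalone reproduction (head names, coefficients, and entropy values are illustrative):

import torch
from tensordict import TensorDict

# Per-head entropies as they would arrive from a composite action space
entropy = TensorDict(
    {"movement": torch.tensor(1.20), "gripper": torch.tensor(1.14)}, batch_size=[]
)
coeff_map = {"movement": 0.01, "gripper": 0.05}

loss_term = None
for head_name, entropy_head in entropy.items(include_nested=True, leaves_only=True):
    # Weight each head's entropy by its own coefficient, then accumulate
    head_loss_term = -coeff_map[head_name] * entropy_head
    loss_term = head_loss_term if loss_term is None else loss_term + head_loss_term

print(loss_term)  # tensor(-0.0690) == -(0.01 * 1.20 + 0.05 * 1.14)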
@@ -972,10 +1043,12 @@ class ClipPPOLoss(PPOLoss):
9721043 ``samples_mc_entropy`` will control how many
9731044 samples will be used to compute this estimate.
9741045 Defaults to ``1``.
975- entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss.
1046+ entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
9761047 * **Scalar**: one value applied to the summed entropy of every action head.
9771048 * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
9781049 Defaults to ``0.01``.
1050+
1051+ See the :ref:`entropy coefficient notes <ppo_entropy_coefficients>` for detailed usage examples and troubleshooting.
9791052 critic_coeff (scalar, optional): critic loss multiplier when computing the total
9801053 loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
9811054 loss from the forward outputs.
@@ -1269,6 +1342,8 @@ class KLPENPPOLoss(PPOLoss):
12691342 * **Scalar**: one value applied to the summed entropy of every action head.
12701343 * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
12711344 Defaults to ``0.01``.
1345+
1346+ See the :ref:`entropy coefficient notes <ppo_entropy_coefficients>` for detailed usage examples and troubleshooting.
12721347 critic_coeff (scalar, optional): critic loss multiplier when computing the total
12731348 loss. Defaults to ``1.0``.
12741349 loss_critic_type (str, optional): loss function for the value discrepancy.