@@ -527,12 +527,9 @@ def _log_weight(
527527 self .actor_network
528528 ) if self .functional else contextlib .nullcontext ():
529529 dist = self .actor_network .get_dist (tensordict )
530- if isinstance (dist , CompositeDistribution ):
531- is_composite = True
532- else :
533- is_composite = False
534530
535- # current log_prob of actions
531+ is_composite = isinstance (dist , CompositeDistribution )
532+
536533 if is_composite :
537534 action = tensordict .select (
538535 * (
@@ -562,25 +559,26 @@ def _log_weight(
562559 log_prob = dist .log_prob (action )
563560 if is_composite :
564561 with set_composite_lp_aggregate (False ):
562+ if log_prob .batch_size != adv_shape :
563+ log_prob .batch_size = adv_shape
565564 if not is_tensor_collection (prev_log_prob ):
566- # this isn't great, in general multihead actions should have a composite log-prob too
565+ # this isn't great: in general, multi-head actions should have a composite log-prob too
567566 warnings .warn (
568567 "You are using a composite distribution, yet your log-probability is a tensor. "
569568 "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at "
570569 "the beginning of your script to get a proper composite log-prob." ,
571570 category = UserWarning ,
572571 )
573- if log_prob .batch_size != adv_shape :
574- log_prob .batch_size = adv_shape
575- if (
576- is_composite
577- and not is_tensor_collection (prev_log_prob )
578- and is_tensor_collection (log_prob )
579- ):
580- log_prob = _sum_td_features (log_prob )
581- log_prob .view_as (prev_log_prob )
572+
573+ if is_tensor_collection (log_prob ):
574+ log_prob = _sum_td_features (log_prob )
575+ log_prob .view_as (prev_log_prob )
582576
583577 log_weight = (log_prob - prev_log_prob ).unsqueeze (- 1 )
578+ if is_tensor_collection (log_weight ):
579+ log_weight = _sum_td_features (log_weight )
580+ log_weight = log_weight .view (adv_shape ).unsqueeze (- 1 )
581+
584582 kl_approx = (prev_log_prob - log_prob ).unsqueeze (- 1 )
585583 if is_tensor_collection (kl_approx ):
586584 kl_approx = _sum_td_features (kl_approx )
@@ -691,9 +689,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
691689 log_weight , dist , kl_approx = self ._log_weight (
692690 tensordict , adv_shape = advantage .shape [:- 1 ]
693691 )
694- if is_tensor_collection (log_weight ):
695- log_weight = _sum_td_features (log_weight )
696- log_weight = log_weight .view (advantage .shape )
697692 neg_loss = log_weight .exp () * advantage
698693 td_out = TensorDict ({"loss_objective" : - neg_loss })
699694 td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
@@ -987,8 +982,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
987982 # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
988983 # dispersion.
989984 lw = log_weight .squeeze ()
990- if not isinstance (lw , torch .Tensor ):
991- lw = _sum_td_features (lw )
992985 ess = (2 * lw .logsumexp (0 ) - (2 * lw ).logsumexp (0 )).exp ()
993986 batch = log_weight .shape [0 ]
994987
@@ -1000,8 +993,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1000993 gain2 = ratio * advantage
1001994
1002995 gain = torch .stack ([gain1 , gain2 ], - 1 ).min (dim = - 1 ).values
1003- if is_tensor_collection (gain ):
1004- gain = _sum_td_features (gain )
1005996 td_out = TensorDict ({"loss_objective" : - gain })
1006997 td_out .set ("clip_fraction" , clip_fraction )
1007998 td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
@@ -1291,8 +1282,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
12911282 tensordict_copy , adv_shape = advantage .shape [:- 1 ]
12921283 )
12931284 neg_loss = log_weight .exp () * advantage
1294- if is_tensor_collection (neg_loss ):
1295- neg_loss = _sum_td_features (neg_loss )
12961285
12971286 with self .actor_network_params .to_module (
12981287 self .actor_network
0 commit comments