File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -598,8 +598,6 @@ def _log_weight(
598598
599599 if is_composite :
600600 with set_composite_lp_aggregate (False ):
601- if log_prob .batch_size != adv_shape :
602- log_prob .batch_size = adv_shape
603601 if not is_tensor_collection (prev_log_prob ):
604602 # this isn't great: in general, multi-head actions should have a composite log-prob too
605603 warnings .warn (
@@ -612,6 +610,8 @@ def _log_weight(
612610 if is_tensor_collection (log_prob ):
613611 log_prob = _sum_td_features (log_prob )
614612 log_prob .view_as (prev_log_prob )
613+ if log_prob .batch_size != adv_shape :
614+ log_prob .batch_size = adv_shape
615615 log_weight = (log_prob - prev_log_prob ).unsqueeze (- 1 )
616616 if is_tensor_collection (log_weight ):
617617 log_weight = _sum_td_features (log_weight )
You can’t perform that action at this time.
0 commit comments