Skip to content

Commit bbd8b93

Browse files
committed
Update
[ghstack-poisoned]
1 parent 5c2d8a8 commit bbd8b93

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

torchrl/collectors/collectors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,18 @@ def update_policy_weights_(
475475
# Apply to local policy
476476
if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
477477
strategy.apply_weights(self.policy, weights)
478+
elif (
479+
hasattr(self, "_original_policy")
480+
and isinstance(self._original_policy, nn.Module)
481+
and hasattr(self, "policy")
482+
and isinstance(self.policy, nn.Module)
483+
):
484+
# If no weights were provided, mirror weights from the original (trainer) policy
485+
from torchrl.weight_update.weight_sync_schemes import WeightStrategy
486+
487+
strategy = WeightStrategy(extract_as="tensordict")
488+
weights = strategy.extract_weights(self._original_policy)
489+
strategy.apply_weights(self.policy, weights)
478490
# Otherwise, no action needed - policy is local and changes are immediately visible
479491

480492
def __iter__(self) -> Iterator[TensorDictBase]:

0 commit comments

Comments
 (0)