File tree: 1 file changed, +12 -0 lines changed.
Original file line number | Diff line number | Diff line change
@@ -475,6 +475,18 @@ def update_policy_weights_(
475475 # Apply to local policy
476476 if hasattr (self , "policy" ) and isinstance (self .policy , nn .Module ):
477477 strategy .apply_weights (self .policy , weights )
478+ elif (
479+ hasattr (self , "_original_policy" )
480+ and isinstance (self ._original_policy , nn .Module )
481+ and hasattr (self , "policy" )
482+ and isinstance (self .policy , nn .Module )
483+ ):
484+ # If no weights were provided, mirror weights from the original (trainer) policy
485+ from torchrl .weight_update .weight_sync_schemes import WeightStrategy
486+
487+ strategy = WeightStrategy (extract_as = "tensordict" )
488+ weights = strategy .extract_weights (self ._original_policy )
489+ strategy .apply_weights (self .policy , weights )
478490 # Otherwise, no action needed - policy is local and changes are immediately visible
479491
480492 def __iter__ (self ) -> Iterator [TensorDictBase ]:
You can’t perform that action at this time.
0 commit comments