 import warnings
 from copy import copy, deepcopy
 from datetime import timedelta
-from typing import Callable, OrderedDict
+from typing import Any, Callable, OrderedDict, Sequence

 import torch.cuda
 from tensordict import TensorDict, TensorDictBase
@@ -131,6 +131,7 @@ def _distributed_init_collection_node(
     num_workers,
     env_make,
     policy,
+    policy_factory,
     frames_per_batch,
     collector_kwargs,
     verbose=True,
@@ -143,6 +144,7 @@ def _distributed_init_collection_node(
         num_workers,
         env_make,
         policy,
+        policy_factory,
         frames_per_batch,
         collector_kwargs,
         verbose=verbose,
@@ -156,6 +158,7 @@ def _run_collector(
     num_workers,
     env_make,
     policy,
+    policy_factory,
     frames_per_batch,
     collector_kwargs,
     verbose=True,
@@ -178,12 +181,17 @@ def _run_collector(
         policy_weights = TensorDict.from_module(policy)
         policy_weights = policy_weights.data.lock_()
     else:
-        warnings.warn(_NON_NN_POLICY_WEIGHTS)
+        if collector_kwargs.get("remote_weight_updater") is None and (
+            policy_factory is None
+            or (isinstance(policy_factory, Sequence) and not any(policy_factory))
+        ):
+            warnings.warn(_NON_NN_POLICY_WEIGHTS)
         policy_weights = TensorDict(lock=True)

     collector = collector_class(
         env_make,
         policy,
+        policy_factory=policy_factory,
         frames_per_batch=frames_per_batch,
         total_frames=-1,
         split_trajs=False,
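Review note (not part of the patch): the new warning gate above, restated as a standalone predicate for readability. The helper name is mine; the logic mirrors the hunk: `_NON_NN_POLICY_WEIGHTS` is only emitted when the policy is not an `nn.Module`, no `remote_weight_updater` arrives via `collector_kwargs`, and no usable `policy_factory` is available.

from typing import Sequence

from torch import nn


def should_warn_non_nn(policy, policy_factory, collector_kwargs) -> bool:
    # Sketch mirroring the condition added in the hunk above.
    if isinstance(policy, nn.Module):
        return False  # weights are extracted via TensorDict.from_module
    if collector_kwargs.get("remote_weight_updater") is not None:
        return False  # a dedicated updater owns weight synchronization
    if isinstance(policy_factory, Sequence):
        return not any(policy_factory)  # warn only if every entry is None
    return policy_factory is None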
@@ -278,8 +286,8 @@ class DistributedDataCollector(DataCollectorBase):
     pickled directly), the :arg:`policy_factory` should be used instead.

     Keyword Args:
-        policy_factory (Callable[[], Callable], optional): a callable that returns
-            a policy instance. This is exclusive with the `policy` argument.
+        policy_factory (Callable[[], Callable] or list of Callable[[], Callable], optional): a callable
+            (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.

         .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

@@ -411,14 +419,16 @@ class DistributedDataCollector(DataCollectorBase):
             to learn more.
             Defaults to ``"submitit"``.
         tcp_port (int, optional): the TCP port to be used. Defaults to 10003.
-        local_weight_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+        local_weight_updater (LocalWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on the local inference worker.
             This is typically not used in :class:`~torchrl.collectors.distributed.DistributedDataCollector` as it
             focuses on distributed environments.
-        remote_weight_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+            Consider using a constructor if the updater needs to be serialized.
+        remote_weight_updater (RemoteWeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
             or its subclass, responsible for updating the policy weights on distributed inference workers.
             If not provided, a :class:`~torchrl.collectors.distributed.DistributedRemoteWeightUpdater` will be used by
             default, which handles weight synchronization across distributed workers.
+            Consider using a constructor if the updater needs to be serialized.

     """

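Review note (not part of the patch): a minimal usage sketch of the keyword arguments documented above. The env constructor, policy factory, and updater are hypothetical stand-ins; only the argument names come from this patch. As enforced in `__init__` below, passing a `policy_factory` without a `remote_weight_updater` raises a `RuntimeError`.

from torchrl.collectors.distributed import DistributedDataCollector

make_env = ...      # hypothetical: a picklable env constructor
make_policy = ...   # hypothetical: builds the policy on each remote worker
my_updater = ...    # hypothetical: a RemoteWeightUpdaterBase instance or constructor

collector = DistributedDataCollector(
    [make_env] * 2,                    # one env constructor per node
    policy=None,                       # exclusive with policy_factory
    policy_factory=make_policy,        # a single factory, or one per node
    remote_weight_updater=my_updater,  # required whenever policy_factory is set
    frames_per_batch=200,
)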
@@ -429,31 +439,37 @@ def __init__(
         create_env_fn,
         policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
         *,
-        policy_factory: Callable[[], Callable] | None = None,
+        policy_factory: Callable[[], Callable]
+        | list[Callable[[], Callable]]
+        | None = None,
         frames_per_batch: int,
         total_frames: int = -1,
-        device: torch.device | list[torch.device] = None,
-        storing_device: torch.device | list[torch.device] = None,
-        env_device: torch.device | list[torch.device] = None,
-        policy_device: torch.device | list[torch.device] = None,
+        device: torch.device | list[torch.device] | None = None,
+        storing_device: torch.device | list[torch.device] | None = None,
+        env_device: torch.device | list[torch.device] | None = None,
+        policy_device: torch.device | list[torch.device] | None = None,
         max_frames_per_traj: int = -1,
         init_random_frames: int = -1,
         reset_at_each_iter: bool = False,
         postproc: Callable | None = None,
         split_trajs: bool = False,
         exploration_type: ExporationType = DEFAULT_EXPLORATION_TYPE,  # noqa
         collector_class: type = SyncDataCollector,
-        collector_kwargs: dict = None,
+        collector_kwargs: dict[str, Any] | None = None,
         num_workers_per_collector: int = 1,
         sync: bool = False,
-        slurm_kwargs: dict | None = None,
+        slurm_kwargs: dict[str, Any] | None = None,
         backend: str = "gloo",
         update_after_each_batch: bool = False,
         max_weight_update_interval: int = -1,
         launcher: str = "submitit",
-        tcp_port: int = None,
-        remote_weight_updater: RemoteWeightUpdaterBase | None = None,
-        local_weight_updater: LocalWeightUpdaterBase | None = None,
+        tcp_port: int | None = None,
+        remote_weight_updater: RemoteWeightUpdaterBase
+        | Callable[[], RemoteWeightUpdaterBase]
+        | None = None,
+        local_weight_updater: LocalWeightUpdaterBase
+        | Callable[[], LocalWeightUpdaterBase]
+        | None = None,
     ):

         if collector_class == "async":
@@ -465,18 +481,22 @@ def __init__(
         self.collector_class = collector_class
         self.env_constructors = create_env_fn
         self.policy = policy
+        if not isinstance(policy_factory, Sequence):
+            policy_factory = [policy_factory for _ in range(len(self.env_constructors))]
+        self.policy_factory = policy_factory
         if isinstance(policy, nn.Module):
             policy_weights = TensorDict.from_module(policy)
             policy_weights = policy_weights.data.lock_()
-        elif policy_factory is not None:
+        elif any(policy_factory):
             policy_weights = None
             if remote_weight_updater is None:
                 raise RuntimeError(
                     "remote_weight_updater must be passed along with "
                     "a policy_factory."
                 )
         else:
-            warnings.warn(_NON_NN_POLICY_WEIGHTS)
+            if not any(policy_factory):
+                warnings.warn(_NON_NN_POLICY_WEIGHTS)
             policy_weights = TensorDict(lock=True)
         self.policy_weights = policy_weights
         self.num_workers = len(create_env_fn)
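Review note (not part of the patch): the broadcast added at the top of this hunk, extracted into a standalone helper for illustration (the helper name is mine). A lone callable, or None, is repeated once per environment constructor so that downstream code can always index `self.policy_factory[i]`:

from typing import Sequence


def broadcast_policy_factory(policy_factory, num_workers):
    # A single callable (or None) is repeated once per worker;
    # a sequence is kept as-is, one entry per collector node.
    if not isinstance(policy_factory, Sequence):
        policy_factory = [policy_factory for _ in range(num_workers)]
    return policy_factory


assert broadcast_policy_factory(None, 3) == [None, None, None]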
@@ -664,12 +684,15 @@ def _make_container(self):
         if self._VERBOSE:
             torchrl_logger.info("making container")
         env_constructor = self.env_constructors[0]
+        kwargs = self.collector_kwargs[0]
         pseudo_collector = SyncDataCollector(
             env_constructor,
-            self.policy,
+            policy=self.policy,
+            policy_factory=self.policy_factory[0],
             frames_per_batch=self._frames_per_batch_corrected,
             total_frames=-1,
             split_trajs=False,
+            **kwargs,
         )
         for _data in pseudo_collector:
             break
@@ -713,6 +736,7 @@ def _init_worker_dist_submitit(self, executor, i):
             self.num_workers_per_collector,
             env_make,
             self.policy,
+            self.policy_factory[i],
             self._frames_per_batch_corrected,
             self.collector_kwargs[i],
             self._VERBOSE,
@@ -734,6 +758,7 @@ def get_env_make(i):
                 "num_workers": self.num_workers_per_collector,
                 "env_make": get_env_make(i),
                 "policy": self.policy,
+                "policy_factory": self.policy_factory[i],
                 "frames_per_batch": self._frames_per_batch_corrected,
                 "collector_kwargs": self.collector_kwargs[i],
             }
@@ -760,6 +785,7 @@ def _init_worker_dist_mp(self, i):
                 self.num_workers_per_collector,
                 env_make,
                 self.policy,
+                self.policy_factory[i],
                 self._frames_per_batch_corrected,
                 self.collector_kwargs[i],
                 self._VERBOSE,