2121from multiprocessing .managers import SyncManager
2222from queue import Empty
2323from textwrap import indent
24- from typing import Any , Callable , Iterator , Sequence
24+ from typing import Any , Callable , Iterator , Sequence , TypeVar
2525
2626import numpy as np
2727import torch
@@ -86,6 +86,8 @@ def cudagraph_mark_step_begin():
8686
8787_is_osx = sys .platform .startswith ("darwin" )
8888
89+ T = TypeVar ("T" )
90+
8991
9092class _Interruptor :
9193 """A class for managing the collection state of a process.
@@ -343,7 +345,15 @@ class SyncDataCollector(DataCollectorBase):
343345
344346 - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
345347
348+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
349+ pickled directly), the ``policy_factory`` argument should be used instead.
350+
346351 Keyword Args:
352+ policy_factory (Callable[[], Callable], optional): a callable that returns
353+ a policy instance. This is exclusive with the `policy` argument.
354+
355+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
356+
347357 frames_per_batch (int): A keyword-only argument representing the total
348358 number of elements in a batch.
349359 total_frames (int): A keyword-only argument representing the total
@@ -515,6 +525,7 @@ def __init__(
515525 policy : None
516526 | (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]) = None ,
517527 * ,
528+ policy_factory : Callable [[], Callable ] | None = None ,
518529 frames_per_batch : int ,
519530 total_frames : int = - 1 ,
520531 device : DEVICE_TYPING = None ,
@@ -558,8 +569,13 @@ def __init__(
558569 env .update_kwargs (create_env_kwargs )
559570
560571 if policy is None :
572+ if policy_factory is not None :
573+ policy = policy_factory ()
574+ else :
575+ policy = RandomPolicy (env .full_action_spec )
576+ elif policy_factory is not None :
577+ raise TypeError ("policy_factory cannot be used with policy argument." )
561578
562- policy = RandomPolicy (env .full_action_spec )
563579 if trust_policy is None :
564580 trust_policy = isinstance (policy , (RandomPolicy , CudaGraphModule ))
565581 self .trust_policy = trust_policy
@@ -1429,17 +1445,22 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
14291445 self ._iter = state_dict ["iter" ]
14301446
14311447 def __repr__ (self ) -> str :
1432- env_str = indent (f"env={ self .env } " , 4 * " " )
1433- policy_str = indent (f"policy={ self .policy } " , 4 * " " )
1434- td_out_str = indent (f"td_out={ getattr (self , '_final_rollout' , None )} " , 4 * " " )
1435- string = (
1436- f"{ self .__class__ .__name__ } ("
1437- f"\n { env_str } ,"
1438- f"\n { policy_str } ,"
1439- f"\n { td_out_str } ,"
1440- f"\n exploration={ self .exploration_type } )"
1441- )
1442- return string
1448+ try :
1449+ env_str = indent (f"env={ self .env } " , 4 * " " )
1450+ policy_str = indent (f"policy={ self .policy } " , 4 * " " )
1451+ td_out_str = indent (
1452+ f"td_out={ getattr (self , '_final_rollout' , None )} " , 4 * " "
1453+ )
1454+ string = (
1455+ f"{ self .__class__ .__name__ } ("
1456+ f"\n { env_str } ,"
1457+ f"\n { policy_str } ,"
1458+ f"\n { td_out_str } ,"
1459+ f"\n exploration={ self .exploration_type } )"
1460+ )
1461+ return string
1462+ except AttributeError :
1463+ return f"{ type (self ).__name__ } (not_init)"
14431464
14441465
14451466class _MultiDataCollector (DataCollectorBase ):
@@ -1469,7 +1490,18 @@ class _MultiDataCollector(DataCollectorBase):
14691490 - In all other cases an attempt to wrap it will be undergone as such:
14701491 ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
14711492
1493+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
1494+ pickled directly), the ``policy_factory`` argument should be used instead.
1495+
14721496 Keyword Args:
1497+ policy_factory (Callable[[], Callable], optional): a callable that returns
1498+ a policy instance. This is exclusive with the `policy` argument.
1499+
1500+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
1501+
1502+ .. warning:: `policy_factory` is currently not compatible with multiprocessed data
1503+ collectors.
1504+
14731505 frames_per_batch (int): A keyword-only argument representing the
14741506 total number of elements in a batch.
14751507 total_frames (int, optional): A keyword-only argument representing the
@@ -1612,6 +1644,7 @@ def __init__(
16121644 policy : None
16131645 | (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]) = None ,
16141646 * ,
1647+ policy_factory : Callable [[], Callable ] | None = None ,
16151648 frames_per_batch : int ,
16161649 total_frames : int | None = - 1 ,
16171650 device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
@@ -1695,27 +1728,36 @@ def __init__(
16951728 self ._get_weights_fn_dict = {}
16961729
16971730 if trust_policy is None :
1698- trust_policy = isinstance (policy , CudaGraphModule )
1731+ trust_policy = policy is not None and isinstance (policy , CudaGraphModule )
16991732 self .trust_policy = trust_policy
17001733
1701- for policy_device , env_maker , env_maker_kwargs in zip (
1702- self .policy_device , self .create_env_fn , self .create_env_kwargs
1703- ):
1704- (policy_copy , get_weights_fn ,) = self ._get_policy_and_device (
1705- policy = policy ,
1706- policy_device = policy_device ,
1707- env_maker = env_maker ,
1708- env_maker_kwargs = env_maker_kwargs ,
1709- )
1710- if type (policy_copy ) is not type (policy ):
1711- policy = policy_copy
1712- weights = (
1713- TensorDict .from_module (policy_copy )
1714- if isinstance (policy_copy , nn .Module )
1715- else TensorDict ()
1734+ if policy_factory is not None and policy is not None :
1735+ raise TypeError ("policy_factory and policy are mutually exclusive" )
1736+ elif policy_factory is None :
1737+ for policy_device , env_maker , env_maker_kwargs in zip (
1738+ self .policy_device , self .create_env_fn , self .create_env_kwargs
1739+ ):
1740+ (policy_copy , get_weights_fn ,) = self ._get_policy_and_device (
1741+ policy = policy ,
1742+ policy_device = policy_device ,
1743+ env_maker = env_maker ,
1744+ env_maker_kwargs = env_maker_kwargs ,
1745+ )
1746+ if type (policy_copy ) is not type (policy ):
1747+ policy = policy_copy
1748+ weights = (
1749+ TensorDict .from_module (policy_copy )
1750+ if isinstance (policy_copy , nn .Module )
1751+ else TensorDict ()
1752+ )
1753+ self ._policy_weights_dict [policy_device ] = weights
1754+ self ._get_weights_fn_dict [policy_device ] = get_weights_fn
1755+ else :
1756+ # TODO: add weight-syncing support for policy_factory in multiprocessed collectors
1757+ raise NotImplementedError (
1758+ "weight syncing is not supported for multiprocessed data collectors at the "
1759+ "moment."
17161760 )
1717- self ._policy_weights_dict [policy_device ] = weights
1718- self ._get_weights_fn_dict [policy_device ] = get_weights_fn
17191761 self .policy = policy
17201762
17211763 remainder = 0
@@ -2782,7 +2824,15 @@ class aSyncDataCollector(MultiaSyncDataCollector):
27822824
27832825 - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
27842826
2827+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
2828+ pickled directly), the ``policy_factory`` argument should be used instead.
2829+
27852830 Keyword Args:
2831+ policy_factory (Callable[[], Callable], optional): a callable that returns
2832+ a policy instance. This is exclusive with the `policy` argument.
2833+
2834+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
2835+
27862836 frames_per_batch (int): A keyword-only argument representing the
27872837 total number of elements in a batch.
27882838 total_frames (int, optional): A keyword-only argument representing the
@@ -2888,8 +2938,10 @@ class aSyncDataCollector(MultiaSyncDataCollector):
28882938 def __init__ (
28892939 self ,
28902940 create_env_fn : Callable [[], EnvBase ],
2891- policy : None | (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]),
2941+ policy : None
2942+ | (TensorDictModule | Callable [[TensorDictBase ], TensorDictBase ]) = None ,
28922943 * ,
2944+ policy_factory : Callable [[], Callable ] | None = None ,
28932945 frames_per_batch : int ,
28942946 total_frames : int | None = - 1 ,
28952947 device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
@@ -2914,6 +2966,7 @@ def __init__(
29142966 super ().__init__ (
29152967 create_env_fn = [create_env_fn ],
29162968 policy = policy ,
2969+ policy_factory = policy_factory ,
29172970 total_frames = total_frames ,
29182971 create_env_kwargs = [create_env_kwargs ],
29192972 max_frames_per_traj = max_frames_per_traj ,
0 commit comments