@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
             If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
+            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause
+            unexpected crashes.
+            Defaults to ``False``.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
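A hedged usage sketch of the new flag (not part of the diff): how `no_cuda_sync=True` might be passed for an environment that lives entirely on CUDA. `make_cuda_env` and `policy` are illustrative placeholders, not names from this PR.

```python
from torchrl.collectors import SyncDataCollector

# `make_cuda_env` is an assumed factory returning a CUDA-resident env
# (e.g. an IsaacLab or ManiSkill wrapper); `policy` is assumed to be a
# module already living on cuda:0.
collector = SyncDataCollector(
    create_env_fn=make_cuda_env,
    policy=policy,
    frames_per_batch=1024,
    total_frames=1_000_000,
    device="cuda:0",
    no_cuda_sync=True,  # bypass explicit syncs; all data stays on CUDA
)
for batch in collector:
    ...  # consume the collected batch
```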
@@ -532,6 +537,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
         **kwargs,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
         else:
             self._sync_policy = _do_nothing
         self.device = device
+        self.no_cuda_sync = no_cuda_sync
         # Check if we need to cast things from device to device
         # If the policy has a None device and the env too, no need to cast (we don't know
         # and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
         Yields: TensorDictBase objects containing (chunks of) trajectories
 
         """
-        if self.storing_device and self.storing_device.type == "cuda":
+        if (
+            not self.no_cuda_sync
+            and self.storing_device
+            and self.storing_device.type == "cuda"
+        ):
             stream = torch.cuda.Stream(self.storing_device, priority=-1)
             event = stream.record_event()
             streams = [stream]
             events = [event]
-        elif self.storing_device is None:
+        elif not self.no_cuda_sync and self.storing_device is None:
             streams = []
             events = []
             # this way of checking cuda is robust to lazy stacks with mismatching shapes
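For context, a minimal standalone sketch (assuming a CUDA device is available) of the stream/event pattern this branch keeps when syncs are not bypassed: work is enqueued on a dedicated high-priority stream and awaited through its event, rather than via a device-wide `torch.cuda.synchronize()`.

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # priority=-1 requests a high-priority stream, as in the collector
    stream = torch.cuda.Stream(device, priority=-1)
    with torch.cuda.stream(stream):
        out = torch.randn(8, device=device).sum()  # enqueued on `stream`
    event = stream.record_event()
    event.synchronize()  # wait for this stream's work only
```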
@@ -1166,10 +1177,17 @@ def rollout(self) -> TensorDictBase:
             else:
                 if self._cast_to_policy_device:
                     if self.policy_device is not None:
+                        # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
+                        non_blocking = (
+                            not self.no_cuda_sync
+                            or self.policy_device.type == "cuda"
+                        )
                         policy_input = self._shuttle.to(
-                            self.policy_device, non_blocking=True
+                            self.policy_device,
+                            non_blocking=non_blocking,
                         )
-                        self._sync_policy()
+                        if not self.no_cuda_sync:
+                            self._sync_policy()
                     elif self.policy_device is None:
                         # we know the tensordict has a device otherwise we would not be here
                         # we can pass this, clear_device_ must have been called earlier
@@ -1191,8 +1209,14 @@ def rollout(self) -> TensorDictBase:
 
             if self._cast_to_env_device:
                 if self.env_device is not None:
-                    env_input = self._shuttle.to(self.env_device, non_blocking=True)
-                    self._sync_env()
+                    non_blocking = (
+                        not self.no_cuda_sync or self.env_device.type == "cuda"
+                    )
+                    env_input = self._shuttle.to(
+                        self.env_device, non_blocking=non_blocking
+                    )
+                    if not self.no_cuda_sync:
+                        self._sync_env()
                 elif self.env_device is None:
                     # we know the tensordict has a device otherwise we would not be here
                     # we can pass this, clear_device_ must have been called earlier
@@ -1216,10 +1240,16 @@ def rollout(self) -> TensorDictBase:
                     return
             else:
                 if self.storing_device is not None:
+                    non_blocking = (
+                        not self.no_cuda_sync or self.storing_device.type == "cuda"
+                    )
                     tensordicts.append(
-                        self._shuttle.to(self.storing_device, non_blocking=True)
+                        self._shuttle.to(
+                            self.storing_device, non_blocking=non_blocking
+                        )
                     )
-                    self._sync_storage()
+                    if not self.no_cuda_sync:
+                        self._sync_storage()
                 else:
                     tensordicts.append(self._shuttle)
 
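All three casts above (policy device, env device and storing device) apply the same rule: once `no_cuda_sync=True`, `non_blocking` is only requested when the *target* device is CUDA. The reason is a standard CUDA footgun, sketched below under the assumption that a CUDA device is available: device-to-host copies issued with `non_blocking=True` can return before the data has landed, which is exactly what the skipped `self._sync_*()` calls guard against.

```python
import torch

if torch.cuda.is_available():
    src = torch.randn(1024, device="cuda")

    # Host-to-device (and device-to-device) non-blocking copies are
    # stream-ordered: later kernels reading `on_gpu` see the full copy.
    on_gpu = torch.randn(1024).pin_memory().to("cuda", non_blocking=True)

    # Device-to-host non-blocking copies may return early; reading
    # `on_cpu` before a sync can observe stale values.
    on_cpu = src.to("cpu", non_blocking=True)
    torch.cuda.synchronize()  # required before trusting `on_cpu`
    print(on_cpu.mean())
```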
@@ -1558,6 +1588,11 @@ class _MultiDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
             If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
+            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause
+            unexpected crashes.
+            Defaults to ``False``.
 
     """
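As with `SyncDataCollector`, a hedged sketch of the flag on a multi-process collector; the value is forwarded to every worker, as the `_run_processes` and `_main_async_collector` hunks below show. `make_cuda_env` and `policy` are again illustrative placeholders.

```python
from torchrl.collectors import MultiSyncDataCollector

collector = MultiSyncDataCollector(
    create_env_fn=[make_cuda_env, make_cuda_env],  # one CUDA env per worker
    policy=policy,
    frames_per_batch=2048,
    total_frames=1_000_000,
    device="cuda:0",
    no_cuda_sync=True,  # forwarded to each worker's inner collector
)
```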
@@ -1597,6 +1632,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
     ):
         self.closed = True
         self.num_workers = len(create_env_fn)
@@ -1636,6 +1672,7 @@ def __init__(
         self.env_device = env_devices
 
         del storing_device, env_device, policy_device, device
+        self.no_cuda_sync = no_cuda_sync
 
         self._use_buffers = use_buffers
         self.replay_buffer = replay_buffer
@@ -1909,6 +1946,7 @@ def _run_processes(self) -> None:
                 "cudagraph_policy": self.cudagraphed_policy_kwargs
                 if self.cudagraphed_policy
                 else False,
+                "no_cuda_sync": self.no_cuda_sync,
             }
             proc = _ProcessNoWarn(
                 target=_main_async_collector,
@@ -2914,6 +2952,7 @@ def _main_async_collector(
     trust_policy: bool = False,
     compile_policy: bool = False,
     cudagraph_policy: bool = False,
+    no_cuda_sync: bool = False,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -2943,6 +2982,7 @@ def _main_async_collector(
         trust_policy=trust_policy,
         compile_policy=compile_policy,
         cudagraph_policy=cudagraph_policy,
+        no_cuda_sync=no_cuda_sync,
     )
     use_buffers = inner_collector._use_buffers
     if verbose: