@@ -3085,6 +3085,7 @@ def rollout(
30853085 set_truncated : bool = False ,
30863086 out = None ,
30873087 trust_policy : bool = False ,
3088+ storing_device : DEVICE_TYPING | None = None ,
30883089 ) -> TensorDictBase :
30893090 """Executes a rollout in the environment.
30903091
@@ -3140,6 +3141,8 @@ def rollout(
31403141 trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
31413142 assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
31423143 and ``False`` otherwise.
3144+ storing_device (Device, optional): if provided, the tensordict will be stored on this device.
3145+ Defaults to ``None``.
31433146
31443147 Returns:
31453148 TensorDict object containing the resulting trajectory.
@@ -3372,6 +3375,9 @@ def rollout(
33723375 "policy" : policy ,
33733376 "policy_device" : policy_device ,
33743377 "env_device" : env_device ,
3378+ "storing_device" : None
3379+ if storing_device is None
3380+ else torch .device (storing_device ),
33753381 "callback" : callback ,
33763382 }
33773383 if break_when_any_done or break_when_all_done :
@@ -3508,6 +3514,7 @@ def _rollout_stop_early(
35083514 policy ,
35093515 policy_device ,
35103516 env_device ,
3517+ storing_device ,
35113518 callback ,
35123519 ):
35133520 # Get the sync func
@@ -3531,7 +3538,10 @@ def _rollout_stop_early(
35313538 else :
35323539 tensordict .clear_device_ ()
35333540 tensordict = self .step (tensordict )
3534- td_append = tensordict .copy ()
3541+ if storing_device is None or tensordict .device == storing_device :
3542+ td_append = tensordict .copy ()
3543+ else :
3544+ td_append = tensordict .to (storing_device )
35353545 if break_when_all_done :
35363546 if partial_steps is not True and not partial_steps .all ():
35373547 # At least one step is partial
@@ -3589,6 +3599,7 @@ def _rollout_nonstop(
35893599 policy ,
35903600 policy_device ,
35913601 env_device ,
3602+ storing_device ,
35923603 callback ,
35933604 ):
35943605 if auto_cast_to_device :
@@ -3614,7 +3625,10 @@ def _rollout_nonstop(
36143625 tensordict = self .step (tensordict_ )
36153626 else :
36163627 tensordict , tensordict_ = self .step_and_maybe_reset (tensordict_ )
3617- tensordicts .append (tensordict )
3628+ if storing_device is None or tensordict .device == storing_device :
3629+ tensordicts .append (tensordict )
3630+ else :
3631+ tensordicts .append (tensordict .to (storing_device ))
36183632 if i == max_steps - 1 :
36193633 # we don't truncate as one could potentially continue the run
36203634 break
0 commit comments