
Commit 6c7bf9d

[Feature] Enable storing rollouts on a different device (#3199)
1 parent 80bfa6e commit 6c7bf9d

File tree: 2 files changed (+38 lines, -2 lines)


test/test_env.py

Lines changed: 22 additions & 0 deletions
@@ -827,6 +827,28 @@ def test_batch_unlocked_with_batch_size(self, device):
         # env.observation_spec = env.observation_spec.clone()
         # assert not env._cache
 
+    @pytest.mark.parametrize("storing_device", get_default_devices())
+    def test_storing_device(self, storing_device):
+        """Ensure rollout data tensors are moved to the requested storing_device."""
+        env = ContinuousActionVecMockEnv(device="cpu")
+
+        td = env.rollout(
+            10,
+            storing_device=torch.device(storing_device)
+            if storing_device is not None
+            else None,
+        )
+
+        expected_device = (
+            torch.device(storing_device) if storing_device is not None else env.device
+        )
+
+        assert td.device == expected_device
+
+        for _, item in td.items(True, True):
+            if isinstance(item, torch.Tensor):
+                assert item.device == expected_device
+
 
 class TestRollout:
     @pytest.mark.skipif(not _has_gym, reason="no gym")
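
Note: the test is parametrized over get_default_devices(), so the non-CPU path is only exercised on machines that expose an accelerator. Assuming the repository's usual test layout, one way to run just this test locally would be:

    python -m pytest test/test_env.py -k test_storing_device -v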

torchrl/envs/common.py

Lines changed: 16 additions & 2 deletions
@@ -3085,6 +3085,7 @@ def rollout(
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
+        storing_device: DEVICE_TYPING | None = None,
     ) -> TensorDictBase:
         """Executes a rollout in the environment.
 
@@ -3140,6 +3141,8 @@ def rollout(
             trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
                 assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
                 and ``False`` otherwise.
+            storing_device (Device, optional): if provided, the tensordict will be stored on this device.
+                Defaults to ``None``.
 
         Returns:
             TensorDict object containing the resulting trajectory.
@@ -3372,6 +3375,9 @@ def rollout(
             "policy": policy,
             "policy_device": policy_device,
             "env_device": env_device,
+            "storing_device": None
+            if storing_device is None
+            else torch.device(storing_device),
             "callback": callback,
         }
         if break_when_any_done or break_when_all_done:
@@ -3508,6 +3514,7 @@ def _rollout_stop_early(
         policy,
         policy_device,
         env_device,
+        storing_device,
         callback,
     ):
         # Get the sync func
@@ -3531,7 +3538,10 @@ def _rollout_stop_early(
             else:
                 tensordict.clear_device_()
             tensordict = self.step(tensordict)
-            td_append = tensordict.copy()
+            if storing_device is None or tensordict.device == storing_device:
+                td_append = tensordict.copy()
+            else:
+                td_append = tensordict.to(storing_device)
             if break_when_all_done:
                 if partial_steps is not True and not partial_steps.all():
                     # At least one step is partial
@@ -3589,6 +3599,7 @@ def _rollout_nonstop(
         policy,
         policy_device,
         env_device,
+        storing_device,
         callback,
     ):
         if auto_cast_to_device:
@@ -3614,7 +3625,10 @@ def _rollout_nonstop(
                 tensordict = self.step(tensordict_)
             else:
                 tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
-            tensordicts.append(tensordict)
+            if storing_device is None or tensordict.device == storing_device:
+                tensordicts.append(tensordict)
+            else:
+                tensordicts.append(tensordict.to(storing_device))
             if i == max_steps - 1:
                 # we don't truncate as one could potentially continue the run
                 break
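
Taken together, these changes let callers keep the rollout buffer off the execution device: each step's tensordict is moved to storing_device before being appended and stacked. A minimal usage sketch, assuming a CUDA device is present and that torchrl's GymEnv wrapper with gymnasium's Pendulum-v1 is installed (the environment and device names here are illustrative, not part of the commit):

    import torch
    from torchrl.envs import GymEnv

    # Illustrative setup: the env steps on CPU, while the collected rollout
    # is stored on the GPU (per the device-move logic added in the diff above).
    env = GymEnv("Pendulum-v1", device="cpu")

    td = env.rollout(
        max_steps=10,
        storing_device=torch.device("cuda:0"),
    )

    # The stacked trajectory and its leaf tensors live on the storing device.
    assert td.device == torch.device("cuda:0")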
