@@ -1471,12 +1471,29 @@ def make_env():
14711471 "transformed_in,transformed_out" , [[True , True ], [False , False ]]
14721472 ) # 1226: effociency
14731473 @pytest .mark .parametrize ("static_seed" , [False , True ])
1474+ @pytest .mark .parametrize ("penv_device" , ["cpu" , None ])
1475+ @pytest .mark .parametrize ("env_device" , ["cpu" , None ])
1476+ @pytest .mark .parametrize ("bwad" , [True , False ])
14741477 def test_parallel_env_seed (
1475- self , env_name , frame_skip , transformed_in , transformed_out , static_seed
1478+ self ,
1479+ env_name ,
1480+ frame_skip ,
1481+ transformed_in ,
1482+ transformed_out ,
1483+ static_seed ,
1484+ penv_device ,
1485+ env_device ,
1486+ bwad ,
14761487 ):
14771488 env_name = env_name ()
14781489 env_parallel , env_serial , _ , _ = _make_envs (
1479- env_name , frame_skip , transformed_in , transformed_out , 5
1490+ env_name ,
1491+ frame_skip ,
1492+ transformed_in ,
1493+ transformed_out ,
1494+ 5 ,
1495+ p_env_device = penv_device ,
1496+ env_device = env_device ,
14801497 )
14811498 try :
14821499 out_seed_serial = env_serial .set_seed (0 , static_seed = static_seed )
@@ -1486,7 +1503,10 @@ def test_parallel_env_seed(
14861503 torch .manual_seed (0 )
14871504
14881505 td_serial = env_serial .rollout (
1489- max_steps = 10 , auto_reset = False , tensordict = td0_serial
1506+ max_steps = 10 ,
1507+ auto_reset = False ,
1508+ tensordict = td0_serial ,
1509+ break_when_any_done = bwad ,
14901510 ).contiguous ()
14911511 key = "pixels" if "pixels" in td_serial .keys () else "observation"
14921512 torch .testing .assert_close (
@@ -1501,7 +1521,10 @@ def test_parallel_env_seed(
15011521 torch .manual_seed (0 )
15021522 assert out_seed_parallel == out_seed_serial
15031523 td_parallel = env_parallel .rollout (
1504- max_steps = 10 , auto_reset = False , tensordict = td0_parallel
1524+ max_steps = 10 ,
1525+ auto_reset = False ,
1526+ tensordict = td0_parallel ,
1527+ break_when_any_done = bwad ,
15051528 ).contiguous ()
15061529 torch .testing .assert_close (
15071530 td_parallel [:, :- 1 ].get (("next" , key )), td_parallel [:, 1 :].get (key )
@@ -1677,7 +1700,7 @@ def test_parallel_env_device(
16771700 frame_skip ,
16781701 transformed_in = transformed_in ,
16791702 transformed_out = transformed_out ,
1680- device = device ,
1703+ env_device = device ,
16811704 N = N ,
16821705 local_mp_ctx = "spawn" ,
16831706 )
0 commit comments