@@ -1464,12 +1464,29 @@ def make_env():
14641464 "transformed_in,transformed_out" , [[True , True ], [False , False ]]
14651465 ) # 1226: effociency
14661466 @pytest .mark .parametrize ("static_seed" , [False , True ])
1467+ @pytest .mark .parametrize ("penv_device" , ["cpu" , None ])
1468+ @pytest .mark .parametrize ("env_device" , ["cpu" , None ])
1469+ @pytest .mark .parametrize ("bwad" , [True , False ])
14671470 def test_parallel_env_seed (
1468- self , env_name , frame_skip , transformed_in , transformed_out , static_seed
1471+ self ,
1472+ env_name ,
1473+ frame_skip ,
1474+ transformed_in ,
1475+ transformed_out ,
1476+ static_seed ,
1477+ penv_device ,
1478+ env_device ,
1479+ bwad ,
14691480 ):
14701481 env_name = env_name ()
14711482 env_parallel , env_serial , _ , _ = _make_envs (
1472- env_name , frame_skip , transformed_in , transformed_out , 5
1483+ env_name ,
1484+ frame_skip ,
1485+ transformed_in ,
1486+ transformed_out ,
1487+ 5 ,
1488+ p_env_device = penv_device ,
1489+ env_device = env_device ,
14731490 )
14741491 try :
14751492 out_seed_serial = env_serial .set_seed (0 , static_seed = static_seed )
@@ -1479,7 +1496,10 @@ def test_parallel_env_seed(
14791496 torch .manual_seed (0 )
14801497
14811498 td_serial = env_serial .rollout (
1482- max_steps = 10 , auto_reset = False , tensordict = td0_serial
1499+ max_steps = 10 ,
1500+ auto_reset = False ,
1501+ tensordict = td0_serial ,
1502+ break_when_any_done = bwad ,
14831503 ).contiguous ()
14841504 key = "pixels" if "pixels" in td_serial .keys () else "observation"
14851505 torch .testing .assert_close (
@@ -1494,7 +1514,10 @@ def test_parallel_env_seed(
14941514 torch .manual_seed (0 )
14951515 assert out_seed_parallel == out_seed_serial
14961516 td_parallel = env_parallel .rollout (
1497- max_steps = 10 , auto_reset = False , tensordict = td0_parallel
1517+ max_steps = 10 ,
1518+ auto_reset = False ,
1519+ tensordict = td0_parallel ,
1520+ break_when_any_done = bwad ,
14981521 ).contiguous ()
14991522 torch .testing .assert_close (
15001523 td_parallel [:, :- 1 ].get (("next" , key )), td_parallel [:, 1 :].get (key )
@@ -1670,7 +1693,7 @@ def test_parallel_env_device(
16701693 frame_skip ,
16711694 transformed_in = transformed_in ,
16721695 transformed_out = transformed_out ,
1673- device = device ,
1696+ env_device = device ,
16741697 N = N ,
16751698 local_mp_ctx = "spawn" ,
16761699 )
0 commit comments