@@ -690,6 +690,39 @@ def make_env():
690690 del env
691691
692692
@pytest.mark.parametrize(
    "break_when_any_done,break_when_all_done",
    [[True, False], [False, True], [False, False]],
)
@pytest.mark.parametrize("n_envs", [1, 4])
def test_collector_outplace_policy(n_envs, break_when_any_done, break_when_all_done):
    """Compare data collected with an in-place policy and an out-of-place one.

    One policy writes ``"action"`` into the tensordict it receives and
    returns that same object; the other writes ``"action"`` into the result
    of ``td.empty()`` and returns that instead. Both are run through a
    ``SyncDataCollector`` on the same environment, and the concatenated
    batches must match according to ``assert_allclose_td``.

    Args:
        n_envs: ``1`` uses a single ``CountingEnv(10)``; any other value
            builds a ``SerialEnv`` of ``n_envs`` counting envs constructed
            with ``10 + i`` for ``i`` in ``range(n_envs)``.
        break_when_any_done: supplied by the parametrization.
        break_when_all_done: supplied by the parametrization.

    NOTE(review): ``break_when_any_done`` and ``break_when_all_done`` are
    never read in this body, so the three parametrized combinations run the
    same scenario - confirm whether they were meant to be forwarded.
    """

    def policy_inplace(td):
        # Mutates the incoming tensordict and hands the same object back.
        td.set("action", torch.ones(td.shape + (1,)))
        return td

    def policy_outplace(td):
        # Leaves the input untouched; the action goes into whatever
        # ``td.empty()`` returns (presumably a fresh, key-less tensordict).
        result = td.empty()
        return result.set("action", torch.ones(td.shape + (1,)))

    if n_envs != 1:
        env_makers = [
            functools.partial(CountingEnv, 10 + offset) for offset in range(n_envs)
        ]
        env = SerialEnv(n_envs, env_makers)
    else:
        env = CountingEnv(10)

    # Same frame budget for both collectors so the outputs are comparable.
    collector_kwargs = {"frames_per_batch": 10, "total_frames": 100}

    # Both collectors share ``env`` and are kept referenced until the end of
    # the test; the env is reset before each collector is built.
    env.reset()
    collector_inplace = SyncDataCollector(env, policy_inplace, **collector_kwargs)
    batches_inplace = list(collector_inplace)
    data_inplace = torch.cat(batches_inplace, dim=0)

    env.reset()
    collector_outplace = SyncDataCollector(env, policy_outplace, **collector_kwargs)
    batches_outplace = list(collector_outplace)
    data_outplace = torch.cat(batches_outplace, dim=0)

    assert_allclose_td(data_inplace, data_outplace)
724+
725+
693726# Deprecated reset_when_done
694727# @pytest.mark.parametrize("num_env", [1, 2])
695728# @pytest.mark.parametrize("env_name", ["vec"])
0 commit comments