235235class TestEnvBase :
236236 def test_run_type_checks (self ):
237237 env = ContinuousActionVecMockEnv ()
238+ env .adapt_dtype = False
238239 env ._run_type_checks = False
239240 check_env_specs (env )
240241 env ._run_type_checks = True
@@ -4112,17 +4113,21 @@ def test_parallel_partial_steps(
41124113 use_buffers = use_buffers ,
41134114 device = device ,
41144115 )
4115- td = penv .reset ()
4116- psteps = torch .zeros (4 , dtype = torch .bool )
4117- psteps [[1 , 3 ]] = True
4118- td .set ("_step" , psteps )
4119-
4120- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4121- td = penv .step (td )
4122- assert (td [0 ].get ("next" ) == 0 ).all ()
4123- assert (td [1 ].get ("next" ) != 0 ).any ()
4124- assert (td [2 ].get ("next" ) == 0 ).all ()
4125- assert (td [3 ].get ("next" ) != 0 ).any ()
4116+ try :
4117+ td = penv .reset ()
4118+ psteps = torch .zeros (4 , dtype = torch .bool )
4119+ psteps [[1 , 3 ]] = True
4120+ td .set ("_step" , psteps )
4121+
4122+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4123+ td = penv .step (td )
4124+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4125+ assert (td [1 ].get ("next" ) != 0 ).any ()
4126+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4127+ assert (td [3 ].get ("next" ) != 0 ).any ()
4128+ finally :
4129+ penv .close ()
4130+ del penv
41264131
41274132 @pytest .mark .parametrize ("use_buffers" , [False , True ])
41284133 def test_parallel_partial_step_and_maybe_reset (
@@ -4135,17 +4140,21 @@ def test_parallel_partial_step_and_maybe_reset(
41354140 use_buffers = use_buffers ,
41364141 device = device ,
41374142 )
4138- td = penv .reset ()
4139- psteps = torch .zeros (4 , dtype = torch .bool )
4140- psteps [[1 , 3 ]] = True
4141- td .set ("_step" , psteps )
4142-
4143- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4144- td , tdreset = penv .step_and_maybe_reset (td )
4145- assert (td [0 ].get ("next" ) == 0 ).all ()
4146- assert (td [1 ].get ("next" ) != 0 ).any ()
4147- assert (td [2 ].get ("next" ) == 0 ).all ()
4148- assert (td [3 ].get ("next" ) != 0 ).any ()
4143+ try :
4144+ td = penv .reset ()
4145+ psteps = torch .zeros (4 , dtype = torch .bool , device = td .get ("done" ).device )
4146+ psteps [[1 , 3 ]] = True
4147+ td .set ("_step" , psteps )
4148+
4149+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4150+ td , tdreset = penv .step_and_maybe_reset (td )
4151+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4152+ assert (td [1 ].get ("next" ) != 0 ).any ()
4153+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4154+ assert (td [3 ].get ("next" ) != 0 ).any ()
4155+ finally :
4156+ penv .close ()
4157+ del penv
41494158
41504159 @pytest .mark .parametrize ("use_buffers" , [False , True ])
41514160 def test_serial_partial_steps (self , use_buffers , device , env_device ):
@@ -4156,17 +4165,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
41564165 use_buffers = use_buffers ,
41574166 device = device ,
41584167 )
4159- td = penv .reset ()
4160- psteps = torch .zeros (4 , dtype = torch .bool )
4161- psteps [[1 , 3 ]] = True
4162- td .set ("_step" , psteps )
4163-
4164- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4165- td = penv .step (td )
4166- assert (td [0 ].get ("next" ) == 0 ).all ()
4167- assert (td [1 ].get ("next" ) != 0 ).any ()
4168- assert (td [2 ].get ("next" ) == 0 ).all ()
4169- assert (td [3 ].get ("next" ) != 0 ).any ()
4168+ try :
4169+ td = penv .reset ()
4170+ psteps = torch .zeros (4 , dtype = torch .bool )
4171+ psteps [[1 , 3 ]] = True
4172+ td .set ("_step" , psteps )
4173+
4174+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4175+ td = penv .step (td )
4176+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4177+ assert (td [1 ].get ("next" ) != 0 ).any ()
4178+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4179+ assert (td [3 ].get ("next" ) != 0 ).any ()
4180+ finally :
4181+ penv .close ()
4182+ del penv
41704183
41714184 @pytest .mark .parametrize ("use_buffers" , [False , True ])
41724185 def test_serial_partial_step_and_maybe_reset (self , use_buffers , device , env_device ):
@@ -4184,9 +4197,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
41844197
41854198 td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
41864199 td = penv .step (td )
4187- assert (td [0 ].get ("next" ) == 0 ). all ( )
4200+ assert_allclose_td (td [0 ].get ("next" ), td [ 0 ], intersection = True )
41884201 assert (td [1 ].get ("next" ) != 0 ).any ()
4189- assert (td [2 ].get ("next" ) == 0 ). all ( )
4202+ assert_allclose_td (td [2 ].get ("next" ), td [ 2 ], intersection = True )
41904203 assert (td [3 ].get ("next" ) != 0 ).any ()
41914204
41924205
0 commit comments