@@ -278,15 +278,19 @@ def pause(self):
278278 f"Collector pause() is not implemented for { type (self ).__name__ } ."
279279 )
280280
281- def async_shutdown (self , timeout : float | None = None ) -> None :
281+ def async_shutdown (
282+ self , timeout : float | None = None , close_env : bool = True
283+ ) -> None :
282284 """Shuts down the collector when started asynchronously with the `start` method.
283285
284286 Arg:
285287 timeout (float, optional): The maximum time to wait for the collector to shutdown.
288+ close_env (bool, optional): If True, the collector will close the contained environment.
289+ Defaults to `True`.
286290
287291 .. seealso:: :meth:`~.start`
288292 """
289- return self .shutdown (timeout = timeout )
293+ return self .shutdown (timeout = timeout , close_env = close_env )
290294
291295 def update_policy_weights_ (
292296 self ,
@@ -342,7 +346,7 @@ def next(self):
342346 return None
343347
344348 @abc .abstractmethod
345- def shutdown (self , timeout : float | None = None ) -> None :
349+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
346350 raise NotImplementedError
347351
348352 @abc .abstractmethod
@@ -1317,12 +1321,14 @@ def _run_iterator(self):
13171321 if self ._stop :
13181322 return
13191323
1320- def async_shutdown (self , timeout : float | None = None ) -> None :
1324+ def async_shutdown (
1325+ self , timeout : float | None = None , close_env : bool = True
1326+ ) -> None :
13211327 """Finishes processes started by ray.init() during async execution."""
13221328 self ._stop = True
13231329 if hasattr (self , "_thread" ) and self ._thread .is_alive ():
13241330 self ._thread .join (timeout = timeout )
1325- self .shutdown ()
1331+ self .shutdown (close_env = close_env )
13261332
13271333 def _postproc (self , tensordict_out ):
13281334 if self .split_trajs :
@@ -1582,14 +1588,20 @@ def reset(self, index=None, **kwargs) -> None:
15821588 )
15831589 self ._shuttle ["collector" ] = collector_metadata
15841590
1585- def shutdown (self , timeout : float | None = None ) -> None :
1586- """Shuts down all workers and/or closes the local environment."""
1591+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
1592+ """Shuts down all workers and/or closes the local environment.
1593+
1594+ Args:
1595+ timeout (float, optional): The timeout for closing pipes between workers.
1596+ No effect for this class.
1597+ close_env (bool, optional): Whether to close the environment. Defaults to `True`.
1598+ """
15871599 if not self .closed :
15881600 self .closed = True
15891601 del self ._shuttle
15901602 if self ._use_buffers :
15911603 del self ._final_rollout
1592- if not self .env .is_closed :
1604+ if close_env and not self .env .is_closed :
15931605 self .env .close ()
15941606 del self .env
15951607 return
@@ -2391,8 +2403,17 @@ def __del__(self):
23912403 # __del__ will not affect the program.
23922404 pass
23932405
2394- def shutdown (self , timeout : float | None = None ) -> None :
2395- """Shuts down all processes. This operation is irreversible."""
2406+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
2407+ """Shuts down all processes. This operation is irreversible.
2408+
2409+ Args:
2410+ timeout (float, optional): The timeout for closing pipes between workers.
2411+ close_env (bool, optional): Whether to close the environment. Defaults to `True`.
2412+ """
2413+ if not close_env :
2414+ raise RuntimeError (
2415+ f"Cannot shutdown { type (self ).__name__ } collector without environment being closed."
2416+ )
23962417 self ._shutdown_main (timeout )
23972418
23982419 def _shutdown_main (self , timeout : float | None = None ) -> None :
@@ -2665,7 +2686,11 @@ def next(self):
26652686 return super ().next ()
26662687
26672688 # for RPC
2668- def shutdown (self , timeout : float | None = None ) -> None :
2689+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
2690+ if not close_env :
2691+ raise RuntimeError (
2692+ f"Cannot shutdown { type (self ).__name__ } collector without environment being closed."
2693+ )
26692694 if hasattr (self , "out_buffer" ):
26702695 del self .out_buffer
26712696 if hasattr (self , "buffers" ):
@@ -3038,9 +3063,13 @@ def next(self):
30383063 return super ().next ()
30393064
30403065 # for RPC
3041- def shutdown (self , timeout : float | None = None ) -> None :
3066+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
30423067 if hasattr (self , "out_tensordicts" ):
30433068 del self .out_tensordicts
3069+ if not close_env :
3070+ raise RuntimeError (
3071+ f"Cannot shutdown { type (self ).__name__ } collector without environment being closed."
3072+ )
30443073 return super ().shutdown (timeout = timeout )
30453074
30463075 # for RPC
@@ -3382,8 +3411,8 @@ def next(self):
33823411 return super ().next ()
33833412
33843413 # for RPC
3385- def shutdown (self , timeout : float | None = None ) -> None :
3386- return super ().shutdown (timeout = timeout )
3414+ def shutdown (self , timeout : float | None = None , close_env : bool = True ) -> None :
3415+ return super ().shutdown (timeout = timeout , close_env = close_env )
33873416
33883417 # for RPC
33893418 def set_seed (self , seed : int , static_seed : bool = False ) -> int :
0 commit comments