     clear_mpi_env_vars,
 )
 
+_CONSOLIDATE_ERR_CAPTURE = (
+    "TensorDict.consolidate failed. You can deactivate the tensordict consolidation via the "
+    "`consolidate` keyword argument of the ParallelEnv constructor."
+)
+
 
 def _check_start(fun):
     def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
@@ -307,6 +312,7 @@ def __init__(
         non_blocking: bool = False,
         mp_start_method: str = None,
         use_buffers: bool = None,
+        consolidate: bool = True,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -315,6 +321,7 @@ def __init__(
         self.num_threads = num_threads
         self._cache_in_keys = None
         self._use_buffers = use_buffers
+        self.consolidate = consolidate
 
         self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
         if callable(create_env_fn):
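
The new `consolidate` flag is surfaced on the `ParallelEnv` constructor. A minimal usage sketch (the `GymEnv` factory and the "Pendulum-v1" task are illustrative assumptions, not part of this diff):

    from torchrl.envs import GymEnv, ParallelEnv

    # consolidate=True (the default) packs each TensorDict exchanged with the
    # workers into a single contiguous storage before it crosses the process
    # boundary; consolidate=False sends the TensorDict unchanged.
    env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"), consolidate=False)
    td = env.reset()
    env.close()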
@@ -841,9 +848,12 @@ def __repr__(self) -> str:
             f"\n\tbatch_size={self.batch_size})"
         )
 
-    def close(self) -> None:
+    def close(self, *, raise_if_closed: bool = True) -> None:
         if self.is_closed:
-            raise RuntimeError("trying to close a closed environment")
+            if raise_if_closed:
+                raise RuntimeError("trying to close a closed environment")
+            else:
+                return
         if self._verbose:
             torchrl_logger.info(f"closing {self.__class__.__name__}")
 
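
With the keyword-only `raise_if_closed` flag, `close()` can be made idempotent at the call site, which is convenient in teardown paths that may run twice. A short sketch (`make_env` is a hypothetical env factory):

    env = ParallelEnv(2, make_env)
    env.close()                       # regular close
    env.close(raise_if_closed=False)  # second call is now a silent no-op
    # env.close() without the flag would raise:
    # RuntimeError("trying to close a closed environment")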
@@ -1470,6 +1480,12 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                         "_non_tensor_keys": self._non_tensor_keys,
                     }
                 )
+            else:
+                kwargs[idx].update(
+                    {
+                        "consolidate": self.consolidate,
+                    }
+                )
             process = proc_fun(target=func, kwargs=kwargs[idx])
             process.daemon = True
             process.start()
@@ -1526,7 +1542,16 @@ def _step_and_maybe_reset_no_buffers(
         else:
             workers_range = range(self.num_workers)
 
-        td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        if self.consolidate:
+            try:
+                td = tensordict.consolidate(
+                    share_memory=True, inplace=True, num_threads=1
+                )
+            except Exception as err:
+                raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+        else:
+            td = tensordict
+
         for i in workers_range:
             # We send the same td multiple times as it is in shared mem and we just need to index it
             # in each process.
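
For reference, `TensorDict.consolidate` is what makes the single shared-memory send above cheap: it copies every leaf tensor into one contiguous storage. A minimal sketch of the call as used in this diff (shapes and keys are illustrative):

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"obs": torch.randn(2, 4), ("next", "obs"): torch.randn(2, 4)},
        batch_size=[2],
    )
    # After this call all leaves share one contiguous (shared-memory) buffer,
    # so only a single segment has to cross the process boundary.
    td = td.consolidate(share_memory=True, inplace=True, num_threads=1)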
@@ -1804,7 +1829,16 @@ def _step_no_buffers(
         else:
             workers_range = range(self.num_workers)
 
-        data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        if self.consolidate:
+            try:
+                data = tensordict.consolidate(
+                    share_memory=True, inplace=True, num_threads=1
+                )
+            except Exception as err:
+                raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+        else:
+            data = tensordict
+
         for i, local_data in zip(workers_range, data.unbind(0)):
             self.parent_channels[i].send(("step", local_data))
         # for i in range(data.shape[0]):
@@ -2026,9 +2060,14 @@ def _reset_no_buffers(
     ) -> Tuple[TensorDictBase, TensorDictBase]:
         if is_tensor_collection(tensordict):
             # tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
-            tensordict = tensordict.consolidate(
-                share_memory=True, num_threads=1
-            ).unbind(0)
+            if self.consolidate:
+                try:
+                    tensordict = tensordict.consolidate(
+                        share_memory=True, num_threads=1
+                    )
+                except Exception as err:
+                    raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+            tensordict = tensordict.unbind(0)
         else:
             tensordict = [None] * self.num_workers
         out_tds = [None] * self.num_workers
@@ -2545,6 +2584,7 @@ def _run_worker_pipe_direct(
     has_lazy_inputs: bool = False,
    verbose: bool = False,
    num_threads: int | None = None,  # for fork start method
+    consolidate: bool = True,
 ) -> None:
     if num_threads is not None:
         torch.set_num_threads(num_threads)
@@ -2634,9 +2674,18 @@ def _run_worker_pipe_direct(
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(
-                cur_td.consolidate(share_memory=True, inplace=True, num_threads=1)
-            )
+            if consolidate:
+                try:
+                    child_pipe.send(
+                        cur_td.consolidate(
+                            share_memory=True, inplace=True, num_threads=1
+                        )
+                    )
+                except Exception as err:
+                    raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+            else:
+                child_pipe.send(cur_td)
+
             del cur_td
 
         elif cmd == "step":
@@ -2650,9 +2699,18 @@ def _run_worker_pipe_direct(
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(
-                next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
-            )
+            if consolidate:
+                try:
+                    child_pipe.send(
+                        next_td.consolidate(
+                            share_memory=True, inplace=True, num_threads=1
+                        )
+                    )
+                except Exception as err:
+                    raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+            else:
+                child_pipe.send(next_td)
+
             del next_td
 
         elif cmd == "step_and_maybe_reset":
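
All of the failure paths above use the same exception-chaining pattern, so the original consolidation error stays attached to the re-raised RuntimeError as its `__cause__`. A self-contained illustration of the pattern (`send_consolidated` is a hypothetical helper, not part of this diff):

    _CONSOLIDATE_ERR_CAPTURE = "TensorDict.consolidate failed. ..."  # shortened

    def send_consolidated(td):
        try:
            return td.consolidate(share_memory=True, inplace=True, num_threads=1)
        except Exception as err:
            # `from err` chains the original exception, so the traceback shows
            # both the consolidation failure and this explanatory message.
            raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err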