@@ -379,6 +379,14 @@ def __init__(
379379
380380 is_spec_locked = EnvBase .is_spec_locked
381381
382+ def select_and_clone (self , name , tensor , selected_keys = None ):
383+ if selected_keys is None :
384+ selected_keys = self ._selected_step_keys
385+ if name in selected_keys :
386+ if self .device is not None and tensor .device != self .device :
387+ return tensor .to (self .device , non_blocking = self .non_blocking )
388+ return tensor .clone ()
389+
382390 @property
383391 def non_blocking (self ):
384392 nb = self ._non_blocking
@@ -1072,12 +1080,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10721080 selected_output_keys = self ._selected_reset_keys_filt
10731081
10741082 # select + clone creates 2 tds, but we can create one only
1075- def select_and_clone (name , tensor ):
1076- if name in selected_output_keys :
1077- return tensor .clone ()
1078-
10791083 out = self .shared_tensordict_parent .named_apply (
1080- select_and_clone ,
1084+ lambda * args : self .select_and_clone (
1085+ * args , selected_keys = selected_output_keys
1086+ ),
10811087 nested_keys = True ,
10821088 filter_empty = True ,
10831089 )
@@ -1150,14 +1156,14 @@ def _step(
11501156 # will be modified in-place at further steps
11511157 device = self .device
11521158
1153- def select_and_clone (name , tensor ):
1154- if name in self ._selected_step_keys :
1155- return tensor .clone ()
1159+ selected_keys = self ._selected_step_keys
11561160
11571161 if partial_steps is not None :
11581162 next_td = TensorDict .lazy_stack ([next_td [i ] for i in workers_range ])
11591163 out = next_td .named_apply (
1160- select_and_clone , nested_keys = True , filter_empty = True
1164+ lambda * args : self .select_and_clone (* args , selected_keys ),
1165+ nested_keys = True ,
1166+ filter_empty = True ,
11611167 )
11621168 if out_tds is not None :
11631169 out .update (
@@ -2010,20 +2016,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
20102016 next_td = shared_tensordict_parent .get ("next" )
20112017 device = self .device
20122018
2013- if next_td .device != device and device is not None :
2014-
2015- def select_and_clone (name , tensor ):
2016- if name in self ._selected_step_keys :
2017- return tensor .to (device , non_blocking = self .non_blocking )
2018-
2019- else :
2020-
2021- def select_and_clone (name , tensor ):
2022- if name in self ._selected_step_keys :
2023- return tensor .clone ()
2024-
20252019 out = next_td .named_apply (
2026- select_and_clone ,
2020+ self . select_and_clone ,
20272021 nested_keys = True ,
20282022 filter_empty = True ,
20292023 device = device ,
@@ -2203,20 +2197,10 @@ def tentative_update(val, other):
22032197 selected_output_keys = self ._selected_reset_keys_filt
22042198 device = self .device
22052199
2206- if self .shared_tensordict_parent .device != device and device is not None :
2207-
2208- def select_and_clone (name , tensor ):
2209- if name in selected_output_keys :
2210- return tensor .to (device , non_blocking = self .non_blocking )
2211-
2212- else :
2213-
2214- def select_and_clone (name , tensor ):
2215- if name in selected_output_keys :
2216- return tensor .clone ()
2217-
22182200 out = self .shared_tensordict_parent .named_apply (
2219- select_and_clone ,
2201+ lambda * args : self .select_and_clone (
2202+ * args , selected_keys = selected_output_keys
2203+ ),
22202204 nested_keys = True ,
22212205 filter_empty = True ,
22222206 device = device ,
0 commit comments