@@ -1947,31 +1947,35 @@ def state_spec_unbatched(self, spec: Composite):
19471947 spec = spec .expand (self .batch_size + spec .shape )
19481948 self .state_spec = spec
19491949
1950- def _skip_tensordict (self , tensordict ) :
1950+ def _skip_tensordict (self , tensordict : TensorDictBase ) -> TensorDictBase :
19511951 # Creates a "skip" tensordict, ie a placeholder for when a step is skipped
19521952 next_tensordict = self .full_done_spec .zero ()
19531953 next_tensordict .update (self .full_observation_spec .zero ())
19541954 next_tensordict .update (self .full_reward_spec .zero ())
19551955
19561956 # Copy the data from tensordict in `next`
1957- def select_and_clone (x , y ):
1957+ keys = set ()
1958+
1959+ def select_and_clone (name , x , y ):
1960+ keys .add (name )
19581961 if y is not None :
19591962 if y .device == x .device :
19601963 return x .clone ()
19611964 return x .to (y .device )
19621965
1963- next_tensordict . update (
1964- tensordict . _fast_apply (
1965- select_and_clone ,
1966- next_tensordict ,
1967- device = next_tensordict .device ,
1968- batch_size = next_tensordict . batch_size ,
1969- default = None ,
1970- filter_empty = True ,
1971- is_leaf = _is_leaf_nontensor ,
1972- )
1966+ result = tensordict . _fast_apply (
1967+ select_and_clone ,
1968+ next_tensordict ,
1969+ device = next_tensordict . device ,
1970+ batch_size = next_tensordict .batch_size ,
1971+ default = None ,
1972+ filter_empty = True ,
1973+ is_leaf = _is_leaf_nontensor ,
1974+ named = True ,
1975+ nested_keys = True ,
19731976 )
1974- return next_tensordict
1977+ result .update (next_tensordict .exclude (* keys ).filter_empty_ ())
1978+ return result
19751979
19761980 def step (self , tensordict : TensorDictBase ) -> TensorDictBase :
19771981 """Makes a step in the environment.
0 commit comments