File tree Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Original file line number Diff line number Diff line change @@ -1933,6 +1933,36 @@ def state_spec_unbatched(self, spec: Composite):
19331933 spec = spec .expand (self .batch_size + spec .shape )
19341934 self .state_spec = spec
19351935
1936+ def _skip_tensordict (self , tensordict : TensorDictBase ) -> TensorDictBase :
1937+ # Creates a "skip" tensordict, ie a placeholder for when a step is skipped
1938+ next_tensordict = self .full_done_spec .zero ()
1939+ next_tensordict .update (self .full_observation_spec .zero ())
1940+ next_tensordict .update (self .full_reward_spec .zero ())
1941+
1942+ # Copy the data from tensordict in `next`
1943+ keys = set ()
1944+
1945+ def select_and_clone (name , x , y ):
1946+ keys .add (name )
1947+ if y is not None :
1948+ if y .device == x .device :
1949+ return x .clone ()
1950+ return x .to (y .device )
1951+
1952+ result = tensordict ._fast_apply (
1953+ select_and_clone ,
1954+ next_tensordict ,
1955+ device = next_tensordict .device ,
1956+ batch_size = next_tensordict .batch_size ,
1957+ default = None ,
1958+ filter_empty = True ,
1959+ is_leaf = _is_leaf_nontensor ,
1960+ named = True ,
1961+ nested_keys = True ,
1962+ )
1963+ result .update (next_tensordict .exclude (* keys ).filter_empty_ ())
1964+ return result
1965+
19361966 def step (self , tensordict : TensorDictBase ) -> TensorDictBase :
19371967 """Makes a step in the environment.
19381968
You can’t perform that action at this time.
0 commit comments