@@ -269,11 +269,20 @@ def _reset(self, tensordict):
269269 batch_size = (
270270 tensordict .batch_size if tensordict is not None else self .batch_size
271271 )
272- if tensordict is None or tensordict . is_empty () :
272+ if tensordict is None or "params" not in tensordict :
273273 # if no ``tensordict`` is passed, we generate a single set of hyperparameters
274274 # Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275 # parameters to get started.
276276 tensordict = self .gen_params (batch_size = batch_size , device = self .device )
277+ elif "th" in tensordict and "thdot" in tensordict :
278+ # we can hard-reset the env too
279+ return tensordict
280+ out = self ._reset_random_data (
281+ tensordict .shape , batch_size , tensordict ["params" ]
282+ )
283+ return out
284+
285+ def _reset_random_data (self , shape , batch_size , params ):
277286
278287 high_th = torch .tensor (self .DEFAULT_X , device = self .device )
279288 high_thdot = torch .tensor (self .DEFAULT_Y , device = self .device )
@@ -284,20 +293,20 @@ def _reset(self, tensordict):
284293 # of simulators run simultaneously. In other contexts, the initial
285294 # random state's shape will depend upon the environment batch-size instead.
286295 th = (
287- torch .rand (tensordict . shape , generator = self .rng , device = self .device )
296+ torch .rand (shape , generator = self .rng , device = self .device )
288297 * (high_th - low_th )
289298 + low_th
290299 )
291300 thdot = (
292- torch .rand (tensordict . shape , generator = self .rng , device = self .device )
301+ torch .rand (shape , generator = self .rng , device = self .device )
293302 * (high_thdot - low_thdot )
294303 + low_thdot
295304 )
296305 out = TensorDict (
297306 {
298307 "th" : th ,
299308 "thdot" : thdot ,
300- "params" : tensordict [ " params" ] ,
309+ "params" : params ,
301310 },
302311 batch_size = batch_size ,
303312 )
0 commit comments