@@ -10837,12 +10837,16 @@ class Timer(Transform):
1083710837
1083810838 Attributes:
1083910839 out_keys: The keys of the output tensordict for the inverse transform. Defaults to
10840- `out_keys = [f"{time_key}_step", f"{time_key}_policy"]`, where the first key represents
10840+ `out_keys = [f"{time_key}_step", f"{time_key}_policy", f"{time_key}_reset" ]`, where the first key represents
1084110841 the time it takes to make a step in the environment, and the second key represents the
10842- time it takes to execute the policy.
10842+ time it takes to execute the policy, and the third key represents the time it takes to call `reset`.
1084310843 time_key: A prefix for the keys where the time intervals will be stored in the tensordict.
1084410844 Defaults to `"time"`.
1084510845
10846+ .. note:: During a succession of rollouts, the time marks of the reset are written at the root (the `"time_reset"`
10847+ entry or equivalent key is always 0 in the `"next"` tensordict). At the root, the `"time_policy"` and `"time_step"`
10848+ entries will be 0 when there is a reset. They will never be `0` in the `"next"` tensordict.
10849+
1084610850 Examples:
1084710851 >>> from torchrl.envs import Timer, GymEnv
1084810852 >>>
@@ -10854,20 +10858,23 @@ class Timer(Transform):
1085410858 >>> print("time for step", r["time_step"])
1085510859 time for step tensor([9.5797e-04, 1.6289e-03, 9.7990e-05, 8.0824e-05, 9.0837e-05, 7.6056e-05,
1085610860 8.2016e-05, 7.6056e-05, 8.1062e-05, 7.7009e-05])
10861+
10862+
1085710863 """
1085810864
1085910865 def __init__ (self , out_keys : Sequence [NestedKey ] = None , time_key : str = "time" ):
1086010866 if out_keys is None :
10861- out_keys = [f"{ time_key } _step" , f"{ time_key } _policy" ]
10862- elif len (out_keys ) != 2 :
10863- raise TypeError (f"Expected two out_keys. Got out_keys={ out_keys } ." )
10867+ out_keys = [f"{ time_key } _step" , f"{ time_key } _policy" , f" { time_key } _reset" ]
10868+ elif len (out_keys ) != 3 :
10869+ raise TypeError (f"Expected three out_keys. Got out_keys={ out_keys } ." )
1086410870 super ().__init__ ([], out_keys )
1086510871 self .time_key = time_key
1086610872 self .last_inv_time = None
1086710873 self .last_call_time = None
10874+ self .last_reset_time = None
1086810875
1086910876 def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
10870- self .last_inv_time = time .time ()
10877+ self .last_reset_time = self . last_inv_time = time .time ()
1087110878 return tensordict
1087210879
1087310880 def _maybe_expand_and_set (self , key , time_elapsed , tensordict ):
@@ -10888,11 +10895,14 @@ def _reset(
1088810895 self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
1088910896 ) -> TensorDictBase :
1089010897 current_time = time .time ()
10891- if self .last_inv_time is not None :
10898+ if self .last_reset_time is not None :
1089210899 time_elapsed = torch .tensor (
10893- current_time - self .last_inv_time , device = tensordict .device
10900+ current_time - self .last_reset_time , device = tensordict .device
10901+ )
10902+ self ._maybe_expand_and_set (self .out_keys [2 ], time_elapsed , tensordict_reset )
10903+ self ._maybe_expand_and_set (
10904+ self .out_keys [0 ], time_elapsed * 0 , tensordict_reset
1089410905 )
10895- self ._maybe_expand_and_set (self .out_keys [0 ], time_elapsed , tensordict_reset )
1089610906 self .last_call_time = current_time
1089710907 # Placeholder
1089810908 self ._maybe_expand_and_set (self .out_keys [1 ], time_elapsed * 0 , tensordict_reset )
@@ -10917,6 +10927,9 @@ def _step(
1091710927 current_time - self .last_inv_time , device = tensordict .device
1091810928 )
1091910929 self ._maybe_expand_and_set (self .out_keys [0 ], time_elapsed , next_tensordict )
10930+ self ._maybe_expand_and_set (
10931+ self .out_keys [2 ], time_elapsed * 0 , next_tensordict
10932+ )
1092010933 self .last_call_time = current_time
1092110934 # presumably no need to worry about batch size incongruities here
1092210935 next_tensordict .set (self .out_keys [1 ], tensordict .get (self .out_keys [1 ]))
@@ -10929,6 +10942,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
1092910942 observation_spec [self .out_keys [1 ]] = Unbounded (
1093010943 shape = observation_spec .shape , device = observation_spec .device
1093110944 )
10945+ observation_spec [self .out_keys [2 ]] = Unbounded (
10946+ shape = observation_spec .shape , device = observation_spec .device
10947+ )
1093210948 return observation_spec
1093310949
1093410950 def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
0 commit comments