@@ -89,8 +89,13 @@ class _AcceptedKeys:
8989 pixels : NestedKey = "pixels"
9090 reco_pixels : NestedKey = "reco_pixels"
9191
92+ tensor_keys : _AcceptedKeys
9293 default_keys = _AcceptedKeys ()
9394
95+ decoder : TensorDictModule
96+ reward_model : TensorDictModule
97+ world_mdel : TensorDictModule
98+
9499 def __init__ (
95100 self ,
96101 world_model : TensorDictModule ,
@@ -238,9 +243,13 @@ class _AcceptedKeys:
238243 done : NestedKey = "done"
239244 terminated : NestedKey = "terminated"
240245
246+ tensor_keys : _AcceptedKeys
241247 default_keys = _AcceptedKeys ()
242248 default_value_estimator = ValueEstimators .TDLambda
243249
250+ value_model : TensorDictModule
251+ actor_model : TensorDictModule
252+
244253 def __init__ (
245254 self ,
246255 actor_model : TensorDictModule ,
@@ -392,8 +401,11 @@ class _AcceptedKeys:
392401
393402 value : NestedKey = "state_value"
394403
404+ tensor_keys : _AcceptedKeys
395405 default_keys = _AcceptedKeys ()
396406
407+ value_model : TensorDictModule
408+
397409 def __init__ (
398410 self ,
399411 value_model : TensorDictModule ,
0 commit comments