@@ -1252,7 +1252,9 @@ class GAE(ValueEstimatorBase):
12521252 Args:
12531253 gamma (scalar): exponential mean discount.
12541254 lmbda (scalar): trajectory discount.
1255- value_network (TensorDictModule): value operator used to retrieve the value estimates.
1255+ value_network (TensorDictModule, optional): value operator used to retrieve the value estimates.
1256+ If ``None``, this module will expect the ``"state_value"`` keys to be already filled, and
1257+ will not call the value network to produce it.
12561258 average_gae (bool): if ``True``, the resulting GAE values will be standardized.
12571259 Default is ``False``.
12581260 differentiable (bool, optional): if ``True``, gradients are propagated through
@@ -1327,7 +1329,7 @@ def __init__(
13271329 * ,
13281330 gamma : float | torch .Tensor ,
13291331 lmbda : float | torch .Tensor ,
1330- value_network : TensorDictModule ,
1332+ value_network : TensorDictModule | None ,
13311333 average_gae : bool = False ,
13321334 differentiable : bool = False ,
13331335 vectorized : bool | None = None ,
@@ -1499,6 +1501,15 @@ def forward(
14991501 value = tensordict .get (self .tensor_keys .value )
15001502 next_value = tensordict .get (("next" , self .tensor_keys .value ))
15011503
1504+ if value is None :
1505+ raise ValueError (
1506+ f"The tensor with key { self .tensor_keys .value } is missing, and no value network was provided."
1507+ )
1508+ if next_value is None :
1509+ raise ValueError (
1510+ f"The tensor with key { ('next' , self .tensor_keys .value )} is missing, and no value network was provided."
1511+ )
1512+
15021513 done = tensordict .get (("next" , self .tensor_keys .done ))
15031514 terminated = tensordict .get (("next" , self .tensor_keys .terminated ), default = done )
15041515 time_dim = self ._get_time_dim (time_dim , tensordict )
0 commit comments