@@ -407,6 +407,7 @@ def __init__(
407407 )
408408 else :
409409 self .critic_coef = None
410+ self ._has_critic = bool (self .critic_coef is not None and self .critic_coef > 0 )
410411 self .loss_critic_type = loss_critic_type
411412 self .normalize_advantage = normalize_advantage
412413 self .normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
@@ -689,7 +690,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
689690 "target_actor_network_params" ,
690691 "target_critic_network_params" ,
691692 )
692- if self .critic_coef is not None :
693+ if self ._has_critic :
693694 return self .critic_coef * loss_value , clip_fraction
694695 return loss_value , clip_fraction
695696
@@ -737,7 +738,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
737738 entropy = _sum_td_features (entropy )
738739 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
739740 td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
740- if self .critic_coef is not None :
741+ if self ._has_critic :
741742 loss_critic , value_clip_fraction = self .loss_critic (tensordict )
742743 td_out .set ("loss_critic" , loss_critic )
743744 if value_clip_fraction is not None :
@@ -1048,7 +1049,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
10481049 entropy = _sum_td_features (entropy )
10491050 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
10501051 td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
1051- if self .critic_coef is not None and self . critic_coef > 0 :
1052+ if self ._has_critic :
10521053 loss_critic , value_clip_fraction = self .loss_critic (tensordict )
10531054 td_out .set ("loss_critic" , loss_critic )
10541055 if value_clip_fraction is not None :
@@ -1375,7 +1376,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
13751376 entropy = _sum_td_features (entropy )
13761377 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
13771378 td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
1378- if self .critic_coef is not None :
1379+ if self ._has_critic :
13791380 loss_critic , value_clip_fraction = self .loss_critic (tensordict_copy )
13801381 td_out .set ("loss_critic" , loss_critic )
13811382 if value_clip_fraction is not None :
0 commit comments