99import warnings
1010from copy import copy
1111from enum import Enum
12- from typing import Any , Callable , Iterable
12+ from typing import Any , Callable , Iterable , TypeVar
1313
1414import torch
1515from tensordict import NestedKey , TensorDict , TensorDictBase , unravel_key
@@ -101,54 +101,44 @@ def decorate_context(*args, **kwargs):
101101 return decorate_context
102102
103103
# TypeVar constrained to ``Tensor`` and ``TensorDict``: helpers annotated with
# it return the same container type they were given (a plain Union would lose
# that input-type/return-type link).
# NOTE(review): assumes ``Tensor`` is imported at module level (e.g.
# ``from torch import Tensor``) — that import is not visible in this chunk; confirm.
TensorLike = TypeVar("TensorLike", Tensor, TensorDict)


def distance_loss(
    v1: TensorLike,
    v2: TensorLike,
    loss_function: str,
    strict_shape: bool = True,
) -> TensorLike:
    """Compute an element-wise distance loss between two tensors or tensordicts.

    Args:
        v1 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v2.
        v2 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v1.
        loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used.
        strict_shape (bool): if False, v1 and v2 are allowed to have a different shape.
            Default is ``True``.

    Returns:
        A tensor or tensordict of the shape v1.view_as(v2) or v2.view_as(v1)
        with values equal to the distance loss between the two.

    """
    # Reject mismatched shapes up front unless the caller explicitly opted out.
    if strict_shape and v1.shape != v2.shape:
        raise RuntimeError(
            f"The input tensors or tensordicts have shapes {v1.shape} and {v2.shape} which are incompatible."
        )

    # Dispatch table instead of an if-chain; reduction="none" keeps the
    # element-wise losses (no aggregation), matching the documented return shape.
    loss_fn = {
        "l2": F.mse_loss,
        "l1": F.l1_loss,
        "smooth_l1": F.smooth_l1_loss,
    }.get(loss_function)
    if loss_fn is None:
        raise NotImplementedError(f"Unknown loss {loss_function}.")
    return loss_fn(v1, v2, reduction="none")
152142
153143
154144class TargetNetUpdater :
@@ -620,13 +610,13 @@ def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor:
620610
621611
622612def _clip_value_loss (
623- old_state_value : torch .Tensor ,
624- state_value : torch .Tensor ,
625- clip_value : torch .Tensor ,
626- target_return : torch .Tensor ,
627- loss_value : torch .Tensor ,
613+ old_state_value : torch .Tensor | TensorDict ,
614+ state_value : torch .Tensor | TensorDict ,
615+ clip_value : torch .Tensor | TensorDict ,
616+ target_return : torch .Tensor | TensorDict ,
617+ loss_value : torch .Tensor | TensorDict ,
628618 loss_critic_type : str ,
629- ):
619+ ) -> tuple [ torch . Tensor | TensorDict , torch . Tensor ] :
630620 """Value clipping method for loss computation.
631621
632622 This method computes a clipped state value from the old state value and the state value,
@@ -644,7 +634,7 @@ def _clip_value_loss(
644634 loss_function = loss_critic_type ,
645635 )
646636 # Chose the most pessimistic value prediction between clipped and non-clipped
647- loss_value = torch .max (loss_value , loss_value_clipped )
637+ loss_value = torch .maximum (loss_value , loss_value_clipped )
648638 return loss_value , clip_fraction
649639
650640
0 commit comments