@@ -122,6 +122,10 @@ class PPOLoss(LossModule):
122122 The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
123123 and preventing large updates. However, it will have no impact if the value estimate was done by the current
124124 version of the value estimator. Defaults to ``None``.
125+ device (torch.device, optional): device of the buffers. Defaults to ``None``.
126+
127+ .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
128+ the storages match the ones that are passed to other components, such as data collectors.
125129
126130 .. note::
127131 The advantage (typically GAE) can be computed by the loss function or
@@ -341,6 +345,7 @@ def __init__(
341345 critic : ProbabilisticTensorDictSequential = None ,
342346 reduction : str = None ,
343347 clip_value : float | None = None ,
348+ device : torch .device | None = None ,
344349 ** kwargs ,
345350 ):
346351 if actor is not None :
@@ -395,10 +400,13 @@ def __init__(
395400 self .separate_losses = separate_losses
396401 self .reduction = reduction
397402
398- try :
399- device = next (self .parameters ()).device
400- except (AttributeError , StopIteration ):
401- device = getattr (torch , "get_default_device" , lambda : torch .device ("cpu" ))()
403+ if device is None :
404+ try :
405+ device = next (self .parameters ()).device
406+ except (AttributeError , StopIteration ):
407+ device = getattr (
408+ torch , "get_default_device" , lambda : torch .device ("cpu" )
409+ )()
402410
403411 self .register_buffer ("entropy_coef" , torch .tensor (entropy_coef , device = device ))
404412 if critic_coef is not None :
@@ -422,7 +430,7 @@ def __init__(
422430
423431 if clip_value is not None :
424432 if isinstance (clip_value , float ):
425- clip_value = torch .tensor (clip_value )
433+ clip_value = torch .tensor (clip_value , device = device )
426434 elif isinstance (clip_value , torch .Tensor ):
427435 if clip_value .numel () != 1 :
428436 raise ValueError (
@@ -866,6 +874,10 @@ class ClipPPOLoss(PPOLoss):
866874 estimate was done by the current version of the value estimator. If instead ``True`` is provided, the
867875 ``clip_epsilon`` parameter will be used as the clipping threshold. If not provided or ``False``, no
868876 clipping will be performed. Defaults to ``False``.
877+ device (torch.device, optional): device of the buffers. Defaults to ``None``.
878+
879+ .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
880+ the storages match the ones that are passed to other components, such as data collectors.
869881
870882 .. note:
871883 The advantage (typically GAE) can be computed by the loss function or
@@ -934,6 +946,7 @@ def __init__(
934946 separate_losses : bool = False ,
935947 reduction : str = None ,
936948 clip_value : bool | float | None = None ,
949+ device : torch .device | None = None ,
937950 ** kwargs ,
938951 ):
939952 # Define clipping of the value loss
@@ -954,13 +967,15 @@ def __init__(
954967 separate_losses = separate_losses ,
955968 reduction = reduction ,
956969 clip_value = clip_value ,
957- ** kwargs ,
970 device = device , ** kwargs ,
958971 )
959- for p in self .parameters ():
960- device = p .device
961- break
962- else :
963- device = None
972+ if device is None :
973+ try :
974+ device = next (self .parameters ()).device
975+ except (AttributeError , StopIteration ):
976+ device = getattr (
977+ torch , "get_default_device" , lambda : torch .device ("cpu" )
978+ )()
964979 self .register_buffer ("clip_epsilon" , torch .tensor (clip_epsilon , device = device ))
965980
966981 @property
@@ -1139,6 +1154,10 @@ class KLPENPPOLoss(PPOLoss):
11391154 The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
11401155 and preventing large updates. However, it will have no impact if the value estimate was done by the current
11411156 version of the value estimator. Defaults to ``None``.
1157+ device (torch.device, optional): device of the buffers. Defaults to ``None``.
1158+
1159+ .. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
1160+ the storages match the ones that are passed to other components, such as data collectors.
11421161
11431162 .. note:
11441163 The advantage (typically GAE) can be computed by the loss function or
@@ -1211,6 +1230,7 @@ def __init__(
12111230 separate_losses : bool = False ,
12121231 reduction : str = None ,
12131232 clip_value : float | None = None ,
1233+ device : torch .device | None = None ,
12141234 ** kwargs ,
12151235 ):
12161236 super ().__init__ (
@@ -1227,12 +1247,21 @@ def __init__(
12271247 separate_losses = separate_losses ,
12281248 reduction = reduction ,
12291249 clip_value = clip_value ,
1250+ device = device ,
12301251 ** kwargs ,
12311252 )
12321253
1254+ if device is None :
1255+ try :
1256+ device = next (self .parameters ()).device
1257+ except (AttributeError , StopIteration ):
1258+ device = getattr (
1259+ torch , "get_default_device" , lambda : torch .device ("cpu" )
1260+ )()
1261+
12331262 self .dtarg = dtarg
12341263 self ._beta_init = beta
1235- self .register_buffer ("beta" , torch .tensor (beta ))
1264+ self .register_buffer ("beta" , torch .tensor (beta , device = device ))
12361265
12371266 if increment < 1.0 :
12381267 raise ValueError (
0 commit comments