@@ -8783,6 +8783,7 @@ def test_ppo(
             value,
             loss_critic_type="l2",
             functional=functional,
+            device=device,
         )
         if composite_action_dist:
             loss_fn.set_keys(
@@ -8883,6 +8884,7 @@ def test_ppo_composite_no_aggregate(
             value,
             loss_critic_type="l2",
             functional=functional,
+            device=device,
         )
         loss_fn.set_keys(
             action=("action", "action1"),
@@ -8943,9 +8945,19 @@ def test_ppo_state_dict(
             device=device, composite_action_dist=composite_action_dist
         )
         value = self._create_mock_value(device=device)
-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )
         sd = loss_fn.state_dict()
-        loss_fn2 = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn2 = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )
         loss_fn2.load_state_dict(sd)

     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -8993,6 +9005,7 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
             value,
             loss_critic_type="l2",
             separate_losses=True,
+            device=device,
         )

         if advantage is not None:
@@ -9100,6 +9113,7 @@ def test_ppo_shared_seq(
             loss_critic_type="l2",
             separate_losses=separate_losses,
             entropy_coef=0.0,
+            device=device,
         )

         loss_fn2 = loss_class(
@@ -9108,6 +9122,7 @@ def test_ppo_shared_seq(
             loss_critic_type="l2",
             separate_losses=separate_losses,
             entropy_coef=0.0,
+            device=device,
         )

         if advantage is not None:
@@ -9202,7 +9217,12 @@ def test_ppo_diff(
         else:
             raise NotImplementedError

-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )

         params = TensorDict.from_module(loss_fn, as_module=True)

@@ -9595,6 +9615,7 @@ def test_ppo_value_clipping(
                 value,
                 loss_critic_type="l2",
                 clip_value=clip_value,
+                device=device,
             )

         else:
@@ -9603,6 +9624,7 @@ def test_ppo_value_clipping(
                 value,
                 loss_critic_type="l2",
                 clip_value=clip_value,
+                device=device,
             )
         advantage(td)
         if composite_action_dist:
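
For reference, a minimal sketch of the pattern these hunks exercise: passing an explicit device keyword when constructing a PPO loss. It assumes PPOLoss accepts the device argument the updated tests pass (as this diff implies); the actor and value networks below are hypothetical stand-ins for the suite's _create_mock_actor / _create_mock_value helpers, not the test code itself.

import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import PPOLoss

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_obs, n_act = 4, 2

# Hypothetical policy: observation -> (loc, scale) -> TanhNormal action dist.
policy_module = TensorDictModule(
    torch.nn.Sequential(
        torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    policy_module,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)

# Hypothetical critic: observation -> "state_value".
value = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

# The keyword the updated tests now pass alongside the existing arguments
# (assumes PPOLoss accepts `device`, as the hunks above imply):
loss_fn = PPOLoss(actor, value, loss_critic_type="l2", device=device)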