@@ -11616,7 +11616,7 @@ def _make_transform_env(self, out_key, base_env):
1161611616 transform = KLRewardTransform(actor, out_keys=out_key)
1161711617 return Compose(
1161811618 TensorDictPrimer(
11619- sample_log_prob =Unbounded(shape=base_env.action_spec.shape[:-1]),
11619+ action_log_prob =Unbounded(shape=base_env.action_spec.shape[:-1]),
1162011620 shape=base_env.shape,
1162111621 ),
1162211622 transform,
@@ -11640,7 +11640,7 @@ def test_transform_no_env(self, in_key, out_key):
1164011640 {
1164111641 "action": torch.randn(*batch, 7),
1164211642 "observation": torch.randn(*batch, 7),
11643- "sample_log_prob ": torch.randn(*batch),
11643+ "action_log_prob ": torch.randn(*batch),
1164411644 },
1164511645 batch,
1164611646 )
@@ -11658,7 +11658,7 @@ def test_transform_compose(self):
1165811658 "action": torch.randn(*batch, 7),
1165911659 "observation": torch.randn(*batch, 7),
1166011660 "next": {t[0].in_keys[0]: torch.zeros(*batch, 1)},
11661- "sample_log_prob ": torch.randn(*batch),
11661+ "action_log_prob ": torch.randn(*batch),
1166211662 },
1166311663 batch,
1166411664 )
@@ -11678,7 +11678,7 @@ def test_transform_env(self, out_key):
1167811678 base_env = self.envclass()
1167911679 torch.manual_seed(0)
1168011680 actor = self._make_actor()
11681- # we need to patch the env and create a sample_log_prob spec to make check_env_specs happy
11681+ # we need to patch the env and create a action_log_prob spec to make check_env_specs happy
1168211682 env = TransformedEnv(
1168311683 base_env,
1168411684 Compose(
@@ -11711,7 +11711,7 @@ def update(x):
1171111711 @pytest.mark.parametrize("out_key", [None, "some_stuff", ["some_stuff"]])
1171211712 def test_single_trans_env_check(self, out_key):
1171311713 base_env = self.envclass()
11714- # we need to patch the env and create a sample_log_prob spec to make check_env_specs happy
11714+ # we need to patch the env and create a action_log_prob spec to make check_env_specs happy
1171511715 env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env))
1171611716 check_env_specs(env)
1171711717
@@ -11776,7 +11776,7 @@ def test_transform_model(self):
1177611776 "action": torch.randn(*batch, 7),
1177711777 "observation": torch.randn(*batch, 7),
1177811778 "next": {t.in_keys[0]: torch.zeros(*batch, 1)},
11779- "sample_log_prob ": torch.randn(*batch),
11779+ "action_log_prob ": torch.randn(*batch),
1178011780 },
1178111781 batch,
1178211782 )
@@ -11796,7 +11796,7 @@ def test_transform_rb(self, rbclass):
1179611796 "action": torch.randn(*batch, 7),
1179711797 "observation": torch.randn(*batch, 7),
1179811798 "next": {t.in_keys[0]: torch.zeros(*batch, 1)},
11799- "sample_log_prob ": torch.randn(*batch),
11799+ "action_log_prob ": torch.randn(*batch),
1180011800 },
1180111801 batch,
1180211802 )
0 commit comments