@@ -4266,11 +4266,15 @@ def test_sac_deactivate_vmap(
42664266 loss_fn_no_vmap.make_value_estimator(td_est)
42674267
42684268 torch.manual_seed(0)
4269- with _check_td_steady(td), pytest.warns(
4270- UserWarning, match="No target network updater"
4271- ):
4272- loss_no_vmap = loss_fn_no_vmap(td)
4273- assert_allclose_td(loss_vmap, loss_no_vmap)
4269+ with pytest.raises(
4270+ NotImplementedError,
4271+ match="This implementation is not supported for torch<2.7",
4272+ ) if torch.__version__ < "2.7" else contextlib.nullcontext():
4273+ with _check_td_steady(td), pytest.warns(
4274+ UserWarning, match="No target network updater"
4275+ ):
4276+ loss_no_vmap = loss_fn_no_vmap(td)
4277+ assert_allclose_td(loss_vmap, loss_no_vmap)
42744278
42754279 @pytest.mark.parametrize("delay_value", (True, False))
42764280 @pytest.mark.parametrize("delay_actor", (True, False))
@@ -5235,12 +5239,16 @@ def test_discrete_sac_deactivate_vmap(
52355239 if td_est is not None:
52365240 loss_fn_no_vmap.make_value_estimator(td_est)
52375241
5238- with _check_td_steady(td), pytest.warns(
5239- UserWarning, match="No target network updater"
5240- ):
5241- torch.manual_seed(1)
5242- loss_no_vmap = loss_fn_no_vmap(td)
5243- assert_allclose_td(loss_vmap, loss_no_vmap)
5242+ with pytest.raises(
5243+ NotImplementedError,
5244+ match="This implementation is not supported for torch<2.7",
5245+ ) if torch.__version__ < "2.7" else contextlib.nullcontext():
5246+ with _check_td_steady(td), pytest.warns(
5247+ UserWarning, match="No target network updater"
5248+ ):
5249+ torch.manual_seed(1)
5250+ loss_no_vmap = loss_fn_no_vmap(td)
5251+ assert_allclose_td(loss_vmap, loss_no_vmap)
52445252
52455253 @pytest.mark.parametrize("delay_qvalue", (True, False))
52465254 @pytest.mark.parametrize("num_qvalue", [2])
@@ -5979,10 +5987,14 @@ def test_crossq_deactivate_vmap(
59795987 if td_est is not None:
59805988 loss_fn_no_vmap.make_value_estimator(td_est)
59815989
5982- with _check_td_steady(td):
5983- torch.manual_seed(1)
5984- loss_no_vmap = loss_fn_no_vmap(td)
5985- assert_allclose_td(loss_vmap, loss_no_vmap)
5990+ with pytest.raises(
5991+ NotImplementedError,
5992+ match="This implementation is not supported for torch<2.7",
5993+ ) if torch.__version__ < "2.7" else contextlib.nullcontext():
5994+ with _check_td_steady(td):
5995+ torch.manual_seed(1)
5996+ loss_no_vmap = loss_fn_no_vmap(td)
5997+ assert_allclose_td(loss_vmap, loss_no_vmap)
59865998
59875999 @pytest.mark.parametrize("num_qvalue", [2])
59886000 @pytest.mark.parametrize("device", get_default_devices())
@@ -7725,12 +7737,16 @@ def test_cql_deactivate_vmap(
77257737 if td_est is not None:
77267738 loss_fn_no_vmap.make_value_estimator(td_est)
77277739
7728- with _check_td_steady(td), pytest.warns(
7729- UserWarning, match="No target network updater"
7730- ):
7731- torch.manual_seed(1)
7732- loss_no_vmap = loss_fn_no_vmap(td)
7733- assert_allclose_td(loss_vmap, loss_no_vmap)
7740+ with pytest.raises(
7741+ NotImplementedError,
7742+ match="This implementation is not supported for torch<2.7",
7743+ ) if torch.__version__ < "2.7" else contextlib.nullcontext():
7744+ with _check_td_steady(td), pytest.warns(
7745+ UserWarning, match="No target network updater"
7746+ ):
7747+ torch.manual_seed(1)
7748+ loss_no_vmap = loss_fn_no_vmap(td)
7749+ assert_allclose_td(loss_vmap, loss_no_vmap)
77347750
77357751 @pytest.mark.parametrize("delay_actor", (True,))
77367752 @pytest.mark.parametrize("delay_qvalue", (True,))
@@ -12796,12 +12812,16 @@ def test_iql_deactivate_vmap(
1279612812 if td_est is not None:
1279712813 loss_fn_no_vmap.make_value_estimator(td_est)
1279812814
12799- with _check_td_steady(td), pytest.warns(
12800- UserWarning, match="No target network updater"
12801- ):
12802- torch.manual_seed(1)
12803- loss_no_vmap = loss_fn_no_vmap(td)
12804- assert_allclose_td(loss_vmap, loss_no_vmap)
12815+ with pytest.raises(
12816+ NotImplementedError,
12817+ match="This implementation is not supported for torch<2.7",
12818+ ) if torch.__version__ < "2.7" else contextlib.nullcontext():
12819+ with _check_td_steady(td), pytest.warns(
12820+ UserWarning, match="No target network updater"
12821+ ):
12822+ torch.manual_seed(1)
12823+ loss_no_vmap = loss_fn_no_vmap(td)
12824+ assert_allclose_td(loss_vmap, loss_no_vmap)
1280512825
1280612826 @pytest.mark.parametrize("num_qvalue", [2])
1280712827 @pytest.mark.parametrize("device", get_default_devices())
@@ -14507,10 +14527,14 @@ def test_gae_recurrent(self, module):
1450714527 shifted=False,
1450814528 deactivate_vmap=True,
1450914529 )
14510- with set_recurrent_mode(True):
14511- r1 = gae(vals.copy())
14512- a1 = r1["advantage"]
14513- torch.testing.assert_close(a0, a1)
14530+ with pytest.raises(
14531+ NotImplementedError,
14532+ match="This implementation is not supported for torch<2.7",
14533+ ) if torch.__version__ < "2.7" else contextlib.nullcontext():
14534+ with set_recurrent_mode(True):
14535+ r1 = gae(vals.copy())
14536+ a1 = r1["advantage"]
14537+ torch.testing.assert_close(a0, a1)
1451414538
1451514539 @pytest.mark.parametrize("device", get_default_devices())
1451614540 @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99])
0 commit comments