|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import sys |
16 | 15 | import unittest.mock as mock |
17 | 16 |
|
18 | 17 | from contextlib import ExitStack as does_not_raise |
|
39 | 38 | from pymc.tests.helpers import SeededTest |
40 | 39 | from pymc.tests.models import simple_init |
41 | 40 |
|
42 | | -IS_LINUX = sys.platform == "linux" |
43 | | -IS_FLOAT32 = aesara.config.floatX == "float32" |
44 | | - |
45 | 41 |
|
46 | 42 | class TestInitNuts(SeededTest): |
47 | 43 | def setup_method(self): |
@@ -705,20 +701,16 @@ def test_model_shared_variable(self): |
705 | 701 | assert post_pred["obs"].shape == (samples, 3) |
706 | 702 | npt.assert_allclose(post_pred["p"], expected_p) |
707 | 703 |
|
708 | | - @pytest.mark.xfail( |
709 | | - condition=IS_FLOAT32 and IS_LINUX, |
710 | | - reason="Test fails on linux float32 systems. See https://github.com/pymc-devs/pymc/issues/5088", |
711 | | - ) |
712 | 704 | def test_deterministic_of_observed(self): |
713 | 705 | rng = np.random.RandomState(8442) |
714 | 706 |
|
715 | 707 | meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10)) |
716 | 708 | meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10)) |
717 | 709 | nchains = 2 |
718 | 710 | with pm.Model(rng_seeder=rng) as model: |
719 | | - mu_in_1 = pm.Normal("mu_in_1", 0, 1) |
| 711 | + mu_in_1 = pm.Normal("mu_in_1", 0, 2) |
720 | 712 | sigma_in_1 = pm.HalfNormal("sd_in_1", 1) |
721 | | - mu_in_2 = pm.Normal("mu_in_2", 0, 1) |
| 713 | + mu_in_2 = pm.Normal("mu_in_2", 0, 2) |
722 | 714 | sigma_in_2 = pm.HalfNormal("sd__in_2", 1) |
723 | 715 |
|
724 | 716 | in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1) |
|
0 commit comments