Skip to content

Commit fd3aefa

Browse files
committed
Move ZSN tests to test_multivariate.py
1 parent 854ef4c commit fd3aefa

File tree

2 files changed

+120
-111
lines changed

2 files changed

+120
-111
lines changed

pymc/tests/distributions/test_continuous.py

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,13 @@
2424

2525
from aeppl.logprob import ParameterValueError
2626
from aesara.compile.mode import Mode
27-
from numpy import AxisError
2827

2928
import pymc as pm
3029

3130
from pymc.aesaraf import floatX
3231
from pymc.distributions import logcdf, logp
3332
from pymc.distributions.continuous import get_tau_sigma, interpolated
3433
from pymc.distributions.dist_math import clipped_beta_rvs
35-
from pymc.distributions.shape_utils import change_dist_size
3634
from pymc.tests.distributions.util import (
3735
BaseTestDistributionRandom,
3836
Circ,
@@ -965,19 +963,6 @@ def test_normal_moment(self, mu, sigma, size, expected):
965963
pm.Normal("x", mu=mu, sigma=sigma, size=size)
966964
assert_moment_is_expected(model, expected)
967965

968-
@pytest.mark.parametrize(
969-
"shape, zerosum_axes, expected",
970-
[
971-
((2, 5), None, np.zeros((2, 5))),
972-
((2, 5, 6), None, np.zeros((2, 5, 6))),
973-
((2, 5, 6), (0, 1), np.zeros((2, 5, 6))),
974-
],
975-
)
976-
def test_zerosum_normal_moment(self, shape, zerosum_axes, expected):
977-
with pm.Model() as model:
978-
pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes)
979-
assert_moment_is_expected(model, expected)
980-
981966
@pytest.mark.parametrize(
982967
"sigma, size, expected",
983968
[
@@ -1817,100 +1802,6 @@ class TestTruncatedNormalUpperArray(BaseTestDistributionRandom):
18171802
]
18181803

18191804

1820-
COORDS = {
1821-
"regions": ["a", "b", "c"],
1822-
"answers": ["yes", "no", "whatever", "don't understand question"],
1823-
}
1824-
1825-
1826-
class TestZeroSumNormal:
1827-
@pytest.mark.parametrize(
1828-
"dims, zerosum_axes, shape",
1829-
[
1830-
(("regions", "answers"), "answers", None),
1831-
(("regions", "answers"), ("regions", "answers"), None),
1832-
(("regions", "answers"), 0, None),
1833-
(("regions", "answers"), -1, None),
1834-
(("regions", "answers"), (0, 1), None),
1835-
(None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))),
1836-
],
1837-
)
1838-
def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1839-
with pm.Model(coords=COORDS) as m:
1840-
v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes)
1841-
s = pm.sample(10, chains=1, tune=100)
1842-
1843-
# to test forward graph
1844-
random_samples = pm.draw(
1845-
v,
1846-
draws=10,
1847-
)
1848-
1849-
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
1850-
1851-
if not isinstance(zerosum_axes, (list, tuple)):
1852-
zerosum_axes = [zerosum_axes]
1853-
1854-
if isinstance(zerosum_axes[0], str):
1855-
for ax in zerosum_axes:
1856-
for samples in [
1857-
s.posterior.v.mean(dim=ax),
1858-
random_samples.mean(axis=dims.index(ax) + 1),
1859-
]:
1860-
assert np.isclose(
1861-
samples, 0
1862-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1863-
1864-
nonzero_axes = list(set(dims).difference(zerosum_axes))
1865-
if nonzero_axes:
1866-
for ax in nonzero_axes:
1867-
for samples in [
1868-
s.posterior.v.mean(dim=ax),
1869-
random_samples.mean(axis=dims.index(ax) + 1),
1870-
]:
1871-
assert not np.isclose(
1872-
samples, 0
1873-
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1874-
1875-
else:
1876-
for ax in zerosum_axes:
1877-
if ax < 0:
1878-
assert np.isclose(
1879-
s.posterior.v.mean(axis=ax), 0
1880-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1881-
else:
1882-
ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling
1883-
assert np.isclose(
1884-
s.posterior.v.mean(axis=ax), 0
1885-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1886-
1887-
@pytest.mark.parametrize(
1888-
"dims, zerosum_axes",
1889-
[
1890-
(("regions", "answers"), 2),
1891-
(("regions", "answers"), (0, -2)),
1892-
],
1893-
)
1894-
def test_zsn_fail_axis(self, dims, zerosum_axes):
1895-
if isinstance(zerosum_axes, (list, tuple)):
1896-
with pytest.raises(ValueError, match="repeated axis"):
1897-
with pm.Model(coords=COORDS) as m:
1898-
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1899-
else:
1900-
with pytest.raises(AxisError, match="out of bounds"):
1901-
with pm.Model(coords=COORDS) as m:
1902-
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1903-
1904-
def test_zsn_change_dist_size(self):
1905-
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9))
1906-
1907-
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
1908-
assert new_dist.eval().shape == (5, 3)
1909-
1910-
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True)
1911-
assert new_dist.eval().shape == (5, 3, 4, 9)
1912-
1913-
19141805
class TestWald(BaseTestDistributionRandom):
19151806
pymc_dist = pm.Wald
19161807
mu, lam, alpha = 1.0, 1.0, 0.0

pymc/tests/distributions/test_multivariate.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from aeppl.logprob import ParameterValueError
2929
from aesara.tensor import TensorVariable
3030
from aesara.tensor.random.utils import broadcast_params
31+
from numpy import AxisError
3132

3233
import pymc as pm
3334

@@ -754,7 +755,12 @@ def test_car_logp(self, sparse, size):
754755

755756
# d x d adjacency matrix for a square (d=4) of rook-adjacent sites
756757
W = np.array(
757-
[[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
758+
[
759+
[0.0, 1.0, 1.0, 0.0],
760+
[1.0, 0.0, 0.0, 1.0],
761+
[1.0, 0.0, 0.0, 1.0],
762+
[0.0, 1.0, 1.0, 0.0],
763+
]
758764
)
759765

760766
tau = 2
@@ -1007,6 +1013,19 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10071013
# MvNormal logp is only implemented for up to 2D variables
10081014
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
10091015

1016+
@pytest.mark.parametrize(
1017+
"shape, zerosum_axes, expected",
1018+
[
1019+
((2, 5), None, np.zeros((2, 5))),
1020+
((2, 5, 6), None, np.zeros((2, 5, 6))),
1021+
((2, 5, 6), (0, 1), np.zeros((2, 5, 6))),
1022+
],
1023+
)
1024+
def test_zerosum_normal_moment(self, shape, zerosum_axes, expected):
1025+
with pm.Model() as model:
1026+
pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes)
1027+
assert_moment_is_expected(model, expected)
1028+
10101029
@pytest.mark.parametrize(
10111030
"mu, size, expected",
10121031
[
@@ -1026,7 +1045,12 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10261045
)
10271046
def test_car_moment(self, mu, size, expected):
10281047
W = np.array(
1029-
[[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
1048+
[
1049+
[0.0, 1.0, 1.0, 0.0],
1050+
[1.0, 0.0, 0.0, 1.0],
1051+
[1.0, 0.0, 0.0, 1.0],
1052+
[0.0, 1.0, 1.0, 0.0],
1053+
]
10301054
)
10311055
tau = 2
10321056
alpha = 0.5
@@ -1367,6 +1391,100 @@ def test_issue_3706(self):
13671391
assert prior_pred["X"].shape == (1, N, 2)
13681392

13691393

1394+
COORDS = {
1395+
"regions": ["a", "b", "c"],
1396+
"answers": ["yes", "no", "whatever", "don't understand question"],
1397+
}
1398+
1399+
1400+
class TestZeroSumNormal:
1401+
@pytest.mark.parametrize(
1402+
"dims, zerosum_axes, shape",
1403+
[
1404+
(("regions", "answers"), "answers", None),
1405+
(("regions", "answers"), ("regions", "answers"), None),
1406+
(("regions", "answers"), 0, None),
1407+
(("regions", "answers"), -1, None),
1408+
(("regions", "answers"), (0, 1), None),
1409+
(None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))),
1410+
],
1411+
)
1412+
def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1413+
with pm.Model(coords=COORDS) as m:
1414+
v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes)
1415+
s = pm.sample(10, chains=1, tune=100)
1416+
1417+
# to test forward graph
1418+
random_samples = pm.draw(
1419+
v,
1420+
draws=10,
1421+
)
1422+
1423+
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
1424+
1425+
if not isinstance(zerosum_axes, (list, tuple)):
1426+
zerosum_axes = [zerosum_axes]
1427+
1428+
if isinstance(zerosum_axes[0], str):
1429+
for ax in zerosum_axes:
1430+
for samples in [
1431+
s.posterior.v.mean(dim=ax),
1432+
random_samples.mean(axis=dims.index(ax) + 1),
1433+
]:
1434+
assert np.isclose(
1435+
samples, 0
1436+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1437+
1438+
nonzero_axes = list(set(dims).difference(zerosum_axes))
1439+
if nonzero_axes:
1440+
for ax in nonzero_axes:
1441+
for samples in [
1442+
s.posterior.v.mean(dim=ax),
1443+
random_samples.mean(axis=dims.index(ax) + 1),
1444+
]:
1445+
assert not np.isclose(
1446+
samples, 0
1447+
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1448+
1449+
else:
1450+
for ax in zerosum_axes:
1451+
if ax < 0:
1452+
assert np.isclose(
1453+
s.posterior.v.mean(axis=ax), 0
1454+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1455+
else:
1456+
ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling
1457+
assert np.isclose(
1458+
s.posterior.v.mean(axis=ax), 0
1459+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1460+
1461+
@pytest.mark.parametrize(
1462+
"dims, zerosum_axes",
1463+
[
1464+
(("regions", "answers"), 2),
1465+
(("regions", "answers"), (0, -2)),
1466+
],
1467+
)
1468+
def test_zsn_fail_axis(self, dims, zerosum_axes):
1469+
if isinstance(zerosum_axes, (list, tuple)):
1470+
with pytest.raises(ValueError, match="repeated axis"):
1471+
with pm.Model(coords=COORDS) as m:
1472+
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1473+
else:
1474+
with pytest.raises(AxisError, match="out of bounds"):
1475+
with pm.Model(coords=COORDS) as m:
1476+
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1477+
1478+
def test_zsn_change_dist_size(self):
1479+
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9))
1480+
1481+
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
1482+
assert new_dist.eval().shape == (5, 3)
1483+
1484+
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True)
1485+
assert new_dist.eval().shape == (5, 3, 4, 9)
1486+
1487+
13701488
class TestMvStudentTCov(BaseTestDistributionRandom):
13711489
def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
13721490
mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size)

0 commit comments

Comments
 (0)