|
28 | 28 | from aeppl.logprob import ParameterValueError |
29 | 29 | from aesara.tensor import TensorVariable |
30 | 30 | from aesara.tensor.random.utils import broadcast_params |
31 | | -from numpy import AxisError |
32 | 31 |
|
33 | 32 | import pymc as pm |
34 | 33 |
|
@@ -1442,21 +1441,29 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): |
1442 | 1441 | ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." |
1443 | 1442 |
|
1444 | 1443 | @pytest.mark.parametrize( |
1445 | | - "dims, zerosum_axes", |
| 1444 | + "error, match, shape, support_shape, zerosum_axes", |
1446 | 1445 | [ |
1447 | | - (("regions", "answers"), 2), |
1448 | | - (("regions", "answers"), (0, -2)), |
| 1446 | + (IndexError, "index out of range", (3, 4, 5), None, 4), |
| 1447 | + (AssertionError, "does not match", (3, 4), 3, None), # support_shape should be 4 |
| 1448 | + ( |
| 1449 | + AssertionError, |
| 1450 | + "does not match", |
| 1451 | + (3, 4), |
| 1452 | + (3, 4), |
| 1453 | + None, |
| 1454 | + ), # doesn't work because zerosum_axes = 1 |
1449 | 1455 | ], |
1450 | 1456 | ) |
1451 | | - def test_zsn_fail_axis(self, dims, zerosum_axes): |
1452 | | - if isinstance(zerosum_axes, (list, tuple)): |
1453 | | - with pytest.raises(ValueError, match="repeated axis"): |
1454 | | - with pm.Model(coords=COORDS) as m: |
1455 | | - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) |
1456 | | - else: |
1457 | | - with pytest.raises(AxisError, match="out of bounds"): |
1458 | | - with pm.Model(coords=COORDS) as m: |
1459 | | - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) |
| 1457 | + def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): |
| 1458 | + with pytest.raises(error, match=match): |
| 1459 | + with pm.Model() as m: |
| 1460 | + _ = pm.ZeroSumNormal( |
| 1461 | + "v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes |
| 1462 | + ) |
| 1463 | + |
| 1464 | + # v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work |
| 1465 | + |
| 1466 | + # v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't |
1460 | 1467 |
|
1461 | 1468 | @pytest.mark.parametrize( |
1462 | 1469 | "zerosum_axes", |
|
0 commit comments