@@ -2437,13 +2437,16 @@ class ZeroSumNormal(Distribution):
24372437 "answers": ["yes", "no", "whatever", "don't understand question"],
24382438 }
24392439 with pm.Model(coords=COORDS) as m:
2440- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")
2440+ # the zero sum axis will be 'answers'
2441+ ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
24412442
24422443 with pm.Model(coords=COORDS) as m:
2443- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))
2444+ # the zero sum axes will be 'answers' and 'regions'
2445+ ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
24442446
24452447 with pm.Model(coords=COORDS) as m:
2446- ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
2448+ # the zero sum axes will be the last two
2449+ ...: v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
24472450 """
24482451 rv_type = ZeroSumNormalRV
24492452
@@ -2525,18 +2528,13 @@ def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
25252528
25262529 @classmethod
25272530 def rv_op (cls , sigma , zerosum_axes , support_shape , size = None ):
2528- # if size is None:
2529- # zerosum_axes_ = np.asarray(zerosum_axes)
2530- # # just a placeholder size to infer minimum shape
2531- # size = np.ones(
2532- # max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
2533- # ).tolist()
2534-
2535- # check if zerosum_axes is valid
2536- # normalize_axis_tuple(zerosum_axes, len(size))
25372531
25382532 shape = to_tuple (size ) + tuple (support_shape )
25392533 normal_dist = ignore_logprob (pm .Normal .dist (sigma = sigma , shape = shape ))
2534+
2535+ if zerosum_axes > normal_dist .ndim :
2536+ raise ValueError ("Shape of distribution is too small for the number of zerosum axes" )
2537+
25402538 normal_dist_ , sigma_ , support_shape_ = (
25412539 normal_dist .type (),
25422540 sigma .type (),
@@ -2555,7 +2553,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
25552553 )(normal_dist , sigma , support_shape )
25562554
25572555 # TODO:
2558- # write __new__
25592556 # refactor ZSN tests
25602557 # test get_support_shape with 2D
25612558 # test ZSN logp
0 commit comments