3636from aesara .tensor .random .utils import broadcast_params
3737from aesara .tensor .slinalg import Cholesky , SolveTriangular
3838from aesara .tensor .type import TensorType
39- from numpy .core .numeric import normalize_axis_tuple
39+
40+ # from numpy.core.numeric import normalize_axis_tuple
4041from scipy import linalg , stats
4142
4243import pymc as pm
6465 _change_dist_size ,
6566 broadcast_dist_samples_to ,
6667 change_dist_size ,
67- convert_dims ,
68+ get_support_shape ,
6869 rv_size_is_none ,
6970 to_tuple ,
7071)
@@ -2389,11 +2390,7 @@ class ZeroSumNormalRV(SymbolicRandomVariable):
23892390 """ZeroSumNormal random variable"""
23902391
23912392 _print_name = ("ZeroSumNormal" , "\\ operatorname{ZeroSumNormal}" )
2392- zerosum_axes = None
2393-
2394- def __init__ (self , * args , zerosum_axes , ** kwargs ):
2395- self .zerosum_axes = zerosum_axes
2396- super ().__init__ (* args , ** kwargs )
2393+ default_output = 0
23972394
23982395
23992396class ZeroSumNormal (Distribution ):
@@ -2447,36 +2444,57 @@ class ZeroSumNormal(Distribution):
24472444 """
24482445 rv_type = ZeroSumNormalRV
24492446
2450- def __new__ (cls , * args , zerosum_axes = None , dims = None , ** kwargs ):
2451- dims = convert_dims (dims )
2452- if zerosum_axes is None :
2453- zerosum_axes = [- 1 ]
2454- if not isinstance (zerosum_axes , (list , tuple )):
2455- zerosum_axes = [zerosum_axes ]
2447+ # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
2448+ # dims = convert_dims(dims)
2449+ # if zerosum_axes is None:
2450+ # zerosum_axes = [-1]
2451+ # if not isinstance(zerosum_axes, (list, tuple)):
2452+ # zerosum_axes = [zerosum_axes]
24562453
2457- if isinstance (zerosum_axes [0 ], str ):
2458- if not dims :
2459- raise ValueError ("You need to specify dims if zerosum_axes are strings." )
2460- else :
2461- zerosum_axes_ = []
2462- for axis in zerosum_axes :
2463- zerosum_axes_ .append (dims .index (axis ))
2464- zerosum_axes = zerosum_axes_
2454+ # if isinstance(zerosum_axes[0], str):
2455+ # if not dims:
2456+ # raise ValueError("You need to specify dims if zerosum_axes are strings.")
2457+ # else:
2458+ # zerosum_axes_ = []
2459+ # for axis in zerosum_axes:
2460+ # zerosum_axes_.append(dims.index(axis))
2461+ # zerosum_axes = zerosum_axes_
24652462
2466- return super ().__new__ (cls , * args , zerosum_axes = zerosum_axes , dims = dims , ** kwargs )
2463+ # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)
24672464
24682465 @classmethod
2469- def dist (cls , sigma = 1 , zerosum_axes = None , ** kwargs ):
2466+ def dist (cls , sigma = 1 , zerosum_axes = None , support_shape = None , ** kwargs ):
24702467 if zerosum_axes is None :
2471- zerosum_axes = [- 1 ]
2472- if not isinstance (zerosum_axes , (list , tuple )):
2473- zerosum_axes = [zerosum_axes ]
2468+ zerosum_axes = 1
2469+ if not isinstance (zerosum_axes , int ):
2470+ raise TypeError ("zerosum_axes has to be an integer" )
2471+ if not zerosum_axes > 0 :
2472+ raise ValueError ("zerosum_axes has to be > 0" )
24742473
24752474 sigma = at .as_tensor_variable (floatX (sigma ))
24762475 if sigma .ndim > 0 :
24772476 raise ValueError ("sigma has to be a scalar" )
24782477
2479- return super ().dist ([sigma ], zerosum_axes = zerosum_axes , ** kwargs )
2478+ support_shape = get_support_shape (
2479+ support_shape = support_shape ,
2480+ shape = kwargs .get ("shape" ),
2481+ ndim_supp = zerosum_axes ,
2482+ )
2483+ if support_shape is None :
2484+ if zerosum_axes > 0 :
2485+ raise ValueError ("You must specify shape or support_shape parameter" )
2486+ # edge case doesn't work for now, because at.stack in get_support_shape fails
2487+ # else:
2488+ # support_shape = () # because it's just a Normal in that case
2489+ support_shape = at .as_tensor_variable (intX (support_shape ))
2490+
2491+ assert zerosum_axes == at .get_vector_length (
2492+ support_shape
2493+ ), "support_shape has to be as long as zerosum_axes"
2494+
2495+ return super ().dist (
2496+ [sigma ], zerosum_axes = zerosum_axes , support_shape = support_shape , ** kwargs
2497+ )
24802498
24812499 # TODO: This is if we want ZeroSum constraint on other dists than Normal
24822500 # def dist(cls, dist, lower, upper, **kwargs):
@@ -2494,39 +2512,55 @@ def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
24942512 # return super().dist([dist, lower, upper], **kwargs)
24952513
24962514 @classmethod
2497- def rv_op (cls , sigma , zerosum_axes , size = None ):
2498- if size is None :
2499- zerosum_axes_ = np .asarray (zerosum_axes )
2500- # just a placeholder size to infer minimum shape
2501- size = np .ones (
2502- max ((max (np .abs (zerosum_axes_ ) - 1 ), max (zerosum_axes_ ))) + 1 , dtype = int
2503- ).tolist ()
2515+ def rv_op (cls , sigma , zerosum_axes , support_shape , size = None ):
2516+ # if size is None:
2517+ # zerosum_axes_ = np.asarray(zerosum_axes)
2518+ # # just a placeholder size to infer minimum shape
2519+ # size = np.ones(
2520+ # max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
2521+ # ).tolist()
25042522
25052523 # check if zerosum_axes is valid
2506- normalize_axis_tuple (zerosum_axes , len (size ))
2507-
2508- normal_dist = ignore_logprob (pm .Normal .dist (sigma = sigma , size = size ))
2509- normal_dist_ , sigma_ = normal_dist .type (), sigma .type ()
2524+ # normalize_axis_tuple(zerosum_axes, len(size))
2525+
2526+ shape = to_tuple (size ) + tuple (support_shape )
2527+ normal_dist = ignore_logprob (pm .Normal .dist (sigma = sigma , shape = shape ))
2528+ normal_dist_ , sigma_ , support_shape_ = (
2529+ normal_dist .type (),
2530+ sigma .type (),
2531+ support_shape .type (),
2532+ )
25102533
25112534 # Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes
25122535 zerosum_rv_ = normal_dist_
2513- for axis in zerosum_axes :
2514- zerosum_rv_ -= zerosum_rv_ .mean (axis = axis , keepdims = True )
2536+ for axis in range ( zerosum_axes ) :
2537+ zerosum_rv_ -= zerosum_rv_ .mean (axis = - axis - 1 , keepdims = True )
25152538
25162539 return ZeroSumNormalRV (
2517- inputs = [normal_dist_ , sigma_ ],
2518- outputs = [zerosum_rv_ ],
2519- zerosum_axes = zerosum_axes ,
2520- ndim_supp = 0 ,
2521- )(normal_dist , sigma )
2540+ inputs = [normal_dist_ , sigma_ , support_shape_ ],
2541+ outputs = [zerosum_rv_ , support_shape_ ],
2542+ ndim_supp = zerosum_axes ,
2543+ )(normal_dist , sigma , support_shape )
2544+
2545+ # TODO:
2546+ # write __new__
2547+ # refactor ZSN tests
2548+ # test get_support_shape with 2D
2549+ # test ZSN logp
2550+ # test ZSN variance
2551+ # fix failing Ubuntu test
25222552
25232553
25242554@_change_dist_size .register (ZeroSumNormalRV )
25252555def change_zerosum_size (op , normal_dist , new_size , expand = False ):
2526- normal_dist , sigma = normal_dist .owner .inputs
2556+ normal_dist , sigma , support_shape = normal_dist .owner .inputs
25272557 if expand :
2528- new_size = tuple (new_size ) + tuple (normal_dist .shape )
2529- return ZeroSumNormal .rv_op (sigma = sigma , zerosum_axes = op .zerosum_axes , size = new_size )
2558+ original_shape = tuple (normal_dist .shape )
2559+ old_size = original_shape [len (original_shape ) - op .ndim_supp :]
2560+ new_size = tuple (new_size ) + old_size
2561+ return ZeroSumNormal .rv_op (
2562+ sigma = sigma , zerosum_axes = op .ndim_supp , support_shape = support_shape , size = new_size
2563+ )
25302564
25312565
25322566@_moment .register (ZeroSumNormalRV )
@@ -2540,20 +2574,22 @@ def zerosum_default_transform(op, rv):
25402574
25412575
25422576@_logprob .register (ZeroSumNormalRV )
2543- def zerosumnormal_logp (op , values , normal_dist , sigma , ** kwargs ):
2577+ def zerosumnormal_logp (op , values , normal_dist , sigma , support_shape , ** kwargs ):
25442578 (value ,) = values
25452579 shape = value .shape
2546- _deg_free_shape = at .inc_subtensor (shape [at .as_tensor_variable (op .zerosum_axes )], - 1 )
2580+ zerosum_axes = op .ndim_supp
2581+ _deg_free_support_shape = at .inc_subtensor (shape [- zerosum_axes :], - 1 )
25472582 _full_size = at .prod (shape )
2548- _degrees_of_freedom = at .prod (_deg_free_shape )
2583+ _degrees_of_freedom = at .prod (_deg_free_support_shape )
25492584 zerosums = [
2550- at .all (at .isclose (at .mean (value , axis = axis ), 0 , atol = 1e-9 )) for axis in op .zerosum_axes
2585+ at .all (at .isclose (at .mean (value , axis = - axis - 1 ), 0 , atol = 1e-9 ))
2586+ for axis in range (zerosum_axes )
25512587 ]
2552- # out = at.sum(
2553- # pm.logp(dist , value) * _degrees_of_freedom / _full_size,
2554- # axis=op. zerosum_axes,
2555- # )
2588+ out = at .sum (
2589+ pm .logp (normal_dist , value ) * _degrees_of_freedom / _full_size ,
2590+ axis = tuple ( np . arange ( - zerosum_axes , 0 )) ,
2591+ )
25562592 # figure out how dimensionality should be handled for logp
25572593 # for now, we assume ZSN is a scalar distribution, which is not correct
2558- out = pm .logp (normal_dist , value ) * _degrees_of_freedom / _full_size
2594+ # out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
25592595 return check_parameters (out , * zerosums , msg = "at.mean(value, axis=zerosum_axes) == 0" )
0 commit comments