1- import pymc3 as pm
1+ from typing import List
2+
3+ try :
4+ import aesara .tensor as aet
5+ except ImportError :
6+ import theano .tensor as aet
7+
28import numpy as np
3- import pandas as pd
4- from typing import *
5- import aesara
6- import aesara .tensor as aet
7-
8-
9- def ZeroSumNormal (
10- name : str ,
11- sigma : Optional [float ] = None ,
12- * ,
13- dims : Union [str , Tuple [str ]],
14- model : Optional [pm .Model ] = None ,
15- ):
16- """
17- Multivariate normal, such that sum(x, axis=-1) = 0.
18-
19- Parameters
20- ----------
21- name: str
22- String name representation of the PyMC variable.
23- sigma: Optional[float], defaults to None
24- Scale for the Normal distribution. If ``None``, a standard Normal is used.
25- dims: Union[str, Tuple[str]]
26- Dimension names for the shape of the distribution.
27- See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
28- model: Optional[pm.Model], defaults to None
29- PyMC model instance. If ``None``, a model instance is created.
30- """
31- if isinstance (dims , str ):
32- dims = (dims ,)
33-
34- model = pm .modelcontext (model )
35- * dims_pre , dim = dims
36- dim_trunc = f"{ dim } _truncated_"
37- (shape ,) = model .shape_from_dims ((dim ,))
38- assert shape >= 1
39-
40- model .add_coords ({f"{ dim } _truncated_" : pd .RangeIndex (shape - 1 )})
41- raw = pm .Normal (f"{ name } _truncated_" , dims = tuple (dims_pre ) + (dim_trunc ,), sigma = sigma )
42- Q = make_sum_zero_hh (shape )
43- draws = aet .dot (raw , Q [:, 1 :].T )
44-
45- #if sigma is not None:
46- # draws = sigma * draws
47-
48- return pm .Deterministic (name , draws , dims = dims )
49-
50-
51-
52- def make_sum_zero_hh (N : int ) -> np .ndarray :
53- """
54- Build a householder transformation matrix that maps e_1 to a vector of all 1s.
55- """
56- e_1 = np .zeros (N )
57- e_1 [0 ] = 1
58- a = np .ones (N )
59- a /= np .sqrt (a @ a )
60- v = a + e_1
61- v /= np .sqrt (v @ v )
62- return np .eye (N ) - 2 * np .outer (v , v )
63-
64- def make_sum_zero_hh (N : int ) -> np .ndarray :
65- """
66- Build a householder transformation matrix that maps e_1 to a vector of all 1s.
67- """
68- e_1 = np .zeros (N )
69- e_1 [0 ] = 1
70- a = np .ones (N )
71- a /= np .sqrt (a @ a )
72- v = a + e_1
73- v /= np .sqrt (v @ v )
74- return np .eye (N ) - 2 * np .outer (v , v )
9+ import pymc3 as pm
10+ from scipy import stats
11+ from pymc3 .distributions .distribution import generate_samples , draw_values
12+
13+ def extend_axis_aet (array , axis ):
14+ n = array .shape [axis ] + 1
15+ sum_vals = array .sum (axis , keepdims = True )
16+ norm = sum_vals / (np .sqrt (n ) + n )
17+ fill_val = norm - sum_vals / np .sqrt (n )
18+
19+ out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
20+ return out - norm .astype (str (array .dtype ))
21+
22+
23+ def extend_axis_rev_aet (array : np .ndarray , axis : int ):
24+ if axis < 0 :
25+ axis = axis % array .ndim
26+ assert axis >= 0 and axis < array .ndim
27+
28+ n = array .shape [axis ]
29+ last = aet .take (array , [- 1 ], axis = axis )
30+
31+ sum_vals = - last * np .sqrt (n )
32+ norm = sum_vals / (np .sqrt (n ) + n )
33+ slice_before = (slice (None , None ),) * axis
34+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
35+
36+
37+ def extend_axis (array , axis ):
38+ n = array .shape [axis ] + 1
39+ sum_vals = array .sum (axis , keepdims = True )
40+ norm = sum_vals / (np .sqrt (n ) + n )
41+ fill_val = norm - sum_vals / np .sqrt (n )
42+
43+ out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
44+ return out - norm .astype (str (array .dtype ))
45+
46+
47+ def extend_axis_rev (array , axis ):
48+ n = array .shape [axis ]
49+ last = np .take (array , [- 1 ], axis = axis )
50+
51+ sum_vals = - last * np .sqrt (n )
52+ norm = sum_vals / (np .sqrt (n ) + n )
53+ slice_before = (slice (None , None ),) * len (array .shape [:axis ])
54+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
55+
56+
57+ class ZeroSumTransform (pm .distributions .transforms .Transform ):
58+ name = "zerosum"
59+
60+ _active_dims : List [int ]
61+
62+ def __init__ (self , active_dims ):
63+ self ._active_dims = active_dims
64+
65+ def forward (self , x ):
66+ for axis in self ._active_dims :
67+ x = extend_axis_rev_aet (x , axis = axis )
68+ return x
69+
70+ def forward_val (self , x , point = None ):
71+ for axis in self ._active_dims :
72+ x = extend_axis_rev (x , axis = axis )
73+ return x
74+
75+ def backward (self , z ):
76+ z = aet .as_tensor_variable (z )
77+ for axis in self ._active_dims :
78+ z = extend_axis_aet (z , axis = axis )
79+ return z
80+
81+ def jacobian_det (self , x ):
82+ return aet .constant (0. )
83+
84+
85+ class ZeroSumNormal (pm .Continuous ):
86+ def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
87+ shape = kwargs .get ("shape" , ())
88+ dims = kwargs .get ("dims" , None )
89+ if isinstance (shape , int ):
90+ shape = (shape ,)
91+
92+ if isinstance (dims , str ):
93+ dims = (dims ,)
94+
95+ self .mu = self .median = self .mode = aet .zeros (shape )
96+ self .sigma = aet .as_tensor_variable (sigma )
97+
98+ if active_dims is None and active_axes is None :
99+ if shape :
100+ active_axes = (- 1 ,)
101+ else :
102+ active_axes = ()
103+
104+ if isinstance (active_axes , int ):
105+ active_axes = (active_axes ,)
106+
107+ if isinstance (active_dims , str ):
108+ active_dims = (active_dims ,)
109+
110+ if active_axes is not None and active_dims is not None :
111+ raise ValueError ("Only one of active_axes and active_dims can be specified." )
112+
113+ if active_dims is not None :
114+ model = pm .modelcontext (None )
115+ print (model .RV_dims )
116+ if dims is None :
117+ raise ValueError ("active_dims can only be used with the dims kwargs." )
118+ active_axes = []
119+ for dim in active_dims :
120+ active_axes .append (dims .index (dim ))
121+
122+ super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
123+
124+ def logp (self , x ):
125+ return pm .Normal .dist (sigma = self .sigma ).logp (x )
126+
127+ @staticmethod
128+ def _random (scale , size ):
129+ samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
130+ return samples - np .mean (samples , axis = - 1 , keepdims = True )
131+
132+ def random (self , point = None , size = None ):
133+ sigma , = draw_values ([self .sigma ], point = point , size = size )
134+ return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
135+
136+ def _distr_parameters_for_repr (self ):
137+ return ["sigma" ]
138+
139+ def logcdf (self , value ):
140+ raise NotImplementedError ()
0 commit comments