1515from scipy import stats
1616from pymc3 .distributions .distribution import generate_samples , draw_values
1717
18+
1819def extend_axis_aet (array , axis ):
1920 n = array .shape [axis ] + 1
2021 sum_vals = array .sum (axis , keepdims = True )
2122 norm = sum_vals / (np .sqrt (n ) + n )
2223 fill_val = norm - sum_vals / np .sqrt (n )
23-
24+
2425 out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
2526 return out - norm .astype (str (array .dtype ))
2627
@@ -32,7 +33,7 @@ def extend_axis_rev_aet(array: np.ndarray, axis: int):
3233
3334 n = array .shape [axis ]
3435 last = aet .take (array , [- 1 ], axis = axis )
35-
36+
3637 sum_vals = - last * np .sqrt (n )
3738 norm = sum_vals / (np .sqrt (n ) + n )
3839 slice_before = (slice (None , None ),) * axis
@@ -44,15 +45,15 @@ def extend_axis(array, axis):
4445 sum_vals = array .sum (axis , keepdims = True )
4546 norm = sum_vals / (np .sqrt (n ) + n )
4647 fill_val = norm - sum_vals / np .sqrt (n )
47-
48+
4849 out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
4950 return out - norm .astype (str (array .dtype ))
5051
5152
5253def extend_axis_rev (array , axis ):
5354 n = array .shape [axis ]
5455 last = np .take (array , [- 1 ], axis = axis )
55-
56+
5657 sum_vals = - last * np .sqrt (n )
5758 norm = sum_vals / (np .sqrt (n ) + n )
5859 slice_before = (slice (None , None ),) * len (array .shape [:axis ])
@@ -61,60 +62,60 @@ def extend_axis_rev(array, axis):
6162
6263class ZeroSumTransform (pm .distributions .transforms .Transform ):
6364 name = "zerosum"
64-
65+
6566 _active_dims : List [int ]
66-
67+
6768 def __init__ (self , active_dims ):
6869 self ._active_dims = active_dims
69-
70+
7071 def forward (self , x ):
7172 for axis in self ._active_dims :
7273 x = extend_axis_rev_aet (x , axis = axis )
7374 return x
74-
75+
7576 def forward_val (self , x , point = None ):
7677 for axis in self ._active_dims :
7778 x = extend_axis_rev (x , axis = axis )
7879 return x
79-
80+
8081 def backward (self , z ):
8182 z = aet .as_tensor_variable (z )
8283 for axis in self ._active_dims :
8384 z = extend_axis_aet (z , axis = axis )
8485 return z
85-
86+
8687 def jacobian_det (self , x ):
87- return aet .constant (0. )
88-
89-
88+ return aet .constant (0.0 )
89+
90+
9091class ZeroSumNormal (pm .Continuous ):
9192 def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
9293 shape = kwargs .get ("shape" , ())
9394 dims = kwargs .get ("dims" , None )
9495 if isinstance (shape , int ):
9596 shape = (shape ,)
96-
97+
9798 if isinstance (dims , str ):
9899 dims = (dims ,)
99100
100101 self .mu = self .median = self .mode = aet .zeros (shape )
101102 self .sigma = aet .as_tensor_variable (sigma )
102-
103+
103104 if active_dims is None and active_axes is None :
104105 if shape :
105106 active_axes = (- 1 ,)
106107 else :
107108 active_axes = ()
108-
109+
109110 if isinstance (active_axes , int ):
110111 active_axes = (active_axes ,)
111-
112+
112113 if isinstance (active_dims , str ):
113114 active_dims = (active_dims ,)
114-
115+
115116 if active_axes is not None and active_dims is not None :
116117 raise ValueError ("Only one of active_axes and active_dims can be specified." )
117-
118+
118119 if active_dims is not None :
119120 model = pm .modelcontext (None )
120121 print (model .RV_dims )
@@ -123,19 +124,19 @@ def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
123124 active_axes = []
124125 for dim in active_dims :
125126 active_axes .append (dims .index (dim ))
126-
127+
127128 super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
128129
129130 def logp (self , x ):
130131 return pm .Normal .dist (sigma = self .sigma ).logp (x )
131-
132+
132133 @staticmethod
133134 def _random (scale , size ):
134135 samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
135136 return samples - np .mean (samples , axis = - 1 , keepdims = True )
136-
137+
137138 def random (self , point = None , size = None ):
138- sigma , = draw_values ([self .sigma ], point = point , size = size )
139+ ( sigma ,) = draw_values ([self .sigma ], point = point , size = size )
139140 return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
140141
141142 def _distr_parameters_for_repr (self ):
0 commit comments