1010from scipy import stats
1111from pymc3 .distributions .distribution import generate_samples , draw_values
1212
13+
1314def extend_axis_aet (array , axis ):
1415 n = array .shape [axis ] + 1
1516 sum_vals = array .sum (axis , keepdims = True )
1617 norm = sum_vals / (np .sqrt (n ) + n )
1718 fill_val = norm - sum_vals / np .sqrt (n )
18-
19+
1920 out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
2021 return out - norm .astype (str (array .dtype ))
2122
@@ -27,7 +28,7 @@ def extend_axis_rev_aet(array: np.ndarray, axis: int):
2728
2829 n = array .shape [axis ]
2930 last = aet .take (array , [- 1 ], axis = axis )
30-
31+
3132 sum_vals = - last * np .sqrt (n )
3233 norm = sum_vals / (np .sqrt (n ) + n )
3334 slice_before = (slice (None , None ),) * axis
@@ -39,15 +40,15 @@ def extend_axis(array, axis):
3940 sum_vals = array .sum (axis , keepdims = True )
4041 norm = sum_vals / (np .sqrt (n ) + n )
4142 fill_val = norm - sum_vals / np .sqrt (n )
42-
43+
4344 out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
4445 return out - norm .astype (str (array .dtype ))
4546
4647
4748def extend_axis_rev (array , axis ):
4849 n = array .shape [axis ]
4950 last = np .take (array , [- 1 ], axis = axis )
50-
51+
5152 sum_vals = - last * np .sqrt (n )
5253 norm = sum_vals / (np .sqrt (n ) + n )
5354 slice_before = (slice (None , None ),) * len (array .shape [:axis ])
@@ -56,60 +57,60 @@ def extend_axis_rev(array, axis):
5657
5758class ZeroSumTransform (pm .distributions .transforms .Transform ):
5859 name = "zerosum"
59-
60+
6061 _active_dims : List [int ]
61-
62+
6263 def __init__ (self , active_dims ):
6364 self ._active_dims = active_dims
64-
65+
6566 def forward (self , x ):
6667 for axis in self ._active_dims :
6768 x = extend_axis_rev_aet (x , axis = axis )
6869 return x
69-
70+
7071 def forward_val (self , x , point = None ):
7172 for axis in self ._active_dims :
7273 x = extend_axis_rev (x , axis = axis )
7374 return x
74-
75+
7576 def backward (self , z ):
7677 z = aet .as_tensor_variable (z )
7778 for axis in self ._active_dims :
7879 z = extend_axis_aet (z , axis = axis )
7980 return z
80-
81+
8182 def jacobian_det (self , x ):
82- return aet .constant (0. )
83-
84-
83+ return aet .constant (0.0 )
84+
85+
8586class ZeroSumNormal (pm .Continuous ):
8687 def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
8788 shape = kwargs .get ("shape" , ())
8889 dims = kwargs .get ("dims" , None )
8990 if isinstance (shape , int ):
9091 shape = (shape ,)
91-
92+
9293 if isinstance (dims , str ):
9394 dims = (dims ,)
9495
9596 self .mu = self .median = self .mode = aet .zeros (shape )
9697 self .sigma = aet .as_tensor_variable (sigma )
97-
98+
9899 if active_dims is None and active_axes is None :
99100 if shape :
100101 active_axes = (- 1 ,)
101102 else :
102103 active_axes = ()
103-
104+
104105 if isinstance (active_axes , int ):
105106 active_axes = (active_axes ,)
106-
107+
107108 if isinstance (active_dims , str ):
108109 active_dims = (active_dims ,)
109-
110+
110111 if active_axes is not None and active_dims is not None :
111112 raise ValueError ("Only one of active_axes and active_dims can be specified." )
112-
113+
113114 if active_dims is not None :
114115 model = pm .modelcontext (None )
115116 print (model .RV_dims )
@@ -118,19 +119,19 @@ def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
118119 active_axes = []
119120 for dim in active_dims :
120121 active_axes .append (dims .index (dim ))
121-
122+
122123 super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
123124
124125 def logp (self , x ):
125126 return pm .Normal .dist (sigma = self .sigma ).logp (x )
126-
127+
127128 @staticmethod
128129 def _random (scale , size ):
129130 samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
130131 return samples - np .mean (samples , axis = - 1 , keepdims = True )
131-
132+
132133 def random (self , point = None , size = None ):
133- sigma , = draw_values ([self .sigma ], point = point , size = size )
134+ ( sigma ,) = draw_values ([self .sigma ], point = point , size = size )
134135 return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
135136
136137 def _distr_parameters_for_repr (self ):
0 commit comments