11< << << << HEAD
2+ < << << << HEAD
3+ == == == =
4+ >> >> >> > cb0c201 (latest ZeroSumNormal code , pymc3 v3 , random seed for sampling )
25from typing import List
36
47try :
58 import aesara .tensor as aet
69except ImportError :
710 import theano .tensor as aet
811
12+ << << << < HEAD
913import numpy as np
1014import pymc3 as pm
1115from scipy import stats
@@ -141,68 +145,48 @@ def logcdf(self, value):
141145 raise NotImplementedError ()
142146== == == =
143147import pymc3 as pm
148+ == == == =
149+ >> >> >> > cb0c201 (latest ZeroSumNormal code , pymc3 v3 , random seed for sampling )
144150import numpy as np
145- import pandas as pd
146- from typing import *
147- import aesara
148- import aesara .tensor as aet
149-
150-
151- def ZeroSumNormal (
152- name : str ,
153- sigma : Optional [float ] = None ,
154- * ,
155- dims : Union [str , Tuple [str ]],
156- model : Optional [pm .Model ] = None ,
157- ):
158- """
159- Multivariate normal, such that sum(x, axis=-1) = 0.
160-
161- Parameters
162- ----------
163- name: str
164- String name representation of the PyMC variable.
165- sigma: Optional[float], defaults to None
166- Scale for the Normal distribution. If ``None``, a standard Normal is used.
167- dims: Union[str, Tuple[str]]
168- Dimension names for the shape of the distribution.
169- See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
170- model: Optional[pm.Model], defaults to None
171- PyMC model instance. If ``None``, a model instance is created.
172- """
173- if isinstance (dims , str ):
174- dims = (dims ,)
151+ import pymc3 as pm
152+ from scipy import stats
153+ from pymc3 .distributions .distribution import generate_samples , draw_values
175154
176- model = pm .modelcontext (model )
177- * dims_pre , dim = dims
178- dim_trunc = f"{ dim } _truncated_"
179- (shape ,) = model .shape_from_dims ((dim ,))
180- assert shape >= 1
155+ def extend_axis_aet (array , axis ):
156+ n = array .shape [axis ] + 1
157+ sum_vals = array .sum (axis , keepdims = True )
158+ norm = sum_vals / (np .sqrt (n ) + n )
159+ fill_val = norm - sum_vals / np .sqrt (n )
160+
161+ out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
162+ return out - norm .astype (str (array .dtype ))
181163
182- model .add_coords ({f"{ dim } _truncated_" : pd .RangeIndex (shape - 1 )})
183- raw = pm .Normal (f"{ name } _truncated_" , dims = tuple (dims_pre ) + (dim_trunc ,), sigma = sigma )
184- Q = make_sum_zero_hh (shape )
185- draws = aet .dot (raw , Q [:, 1 :].T )
186164
187- #if sigma is not None:
188- # draws = sigma * draws
165+ def extend_axis_rev_aet (array : np .ndarray , axis : int ):
166+ if axis < 0 :
167+ axis = axis % array .ndim
168+ assert axis >= 0 and axis < array .ndim
189169
190- return pm .Deterministic (name , draws , dims = dims )
170+ n = array .shape [axis ]
171+ last = aet .take (array , [- 1 ], axis = axis )
172+
173+ sum_vals = - last * np .sqrt (n )
174+ norm = sum_vals / (np .sqrt (n ) + n )
175+ slice_before = (slice (None , None ),) * axis
176+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
191177
192178
179+ def extend_axis (array , axis ):
180+ n = array .shape [axis ] + 1
181+ sum_vals = array .sum (axis , keepdims = True )
182+ norm = sum_vals / (np .sqrt (n ) + n )
183+ fill_val = norm - sum_vals / np .sqrt (n )
184+
185+ out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
186+ return out - norm .astype (str (array .dtype ))
193187
194- def make_sum_zero_hh (N : int ) -> np .ndarray :
195- """
196- Build a householder transformation matrix that maps e_1 to a vector of all 1s.
197- """
198- e_1 = np .zeros (N )
199- e_1 [0 ] = 1
200- a = np .ones (N )
201- a /= np .sqrt (a @ a )
202- v = a + e_1
203- v /= np .sqrt (v @ v )
204- return np .eye (N ) - 2 * np .outer (v , v )
205188
189+ < << << << HEAD
206190def make_sum_zero_hh (N : int ) -> np .ndarray :
207191 """
208192 Build a householder transformation matrix that maps e_1 to a vector of all 1s.
@@ -215,3 +199,99 @@ def make_sum_zero_hh(N: int) -> np.ndarray:
215199 v /= np .sqrt (v @ v )
216200 return np .eye (N ) - 2 * np .outer (v , v )
217201> >> >> >> 2 da3052 (ZeroSumNormal : initial commit )
202+ == == == =
203+ def extend_axis_rev (array , axis ):
204+ n = array .shape [axis ]
205+ last = np .take (array , [- 1 ], axis = axis )
206+
207+ sum_vals = - last * np .sqrt (n )
208+ norm = sum_vals / (np .sqrt (n ) + n )
209+ slice_before = (slice (None , None ),) * len (array .shape [:axis ])
210+ return array [slice_before + (slice (None , - 1 ),)] + norm .astype (str (array .dtype ))
211+
212+
213+ class ZeroSumTransform (pm .distributions .transforms .Transform ):
214+ name = "zerosum"
215+
216+ _active_dims : List [int ]
217+
218+ def __init__ (self , active_dims ):
219+ self ._active_dims = active_dims
220+
221+ def forward (self , x ):
222+ for axis in self ._active_dims :
223+ x = extend_axis_rev_aet (x , axis = axis )
224+ return x
225+
226+ def forward_val (self , x , point = None ):
227+ for axis in self ._active_dims :
228+ x = extend_axis_rev (x , axis = axis )
229+ return x
230+
231+ def backward (self , z ):
232+ z = aet .as_tensor_variable (z )
233+ for axis in self ._active_dims :
234+ z = extend_axis_aet (z , axis = axis )
235+ return z
236+
237+ def jacobian_det (self , x ):
238+ return aet .constant (0. )
239+
240+
241+ class ZeroSumNormal (pm .Continuous ):
242+ def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
243+ shape = kwargs .get ("shape" , ())
244+ dims = kwargs .get ("dims" , None )
245+ if isinstance (shape , int ):
246+ shape = (shape ,)
247+
248+ if isinstance (dims , str ):
249+ dims = (dims ,)
250+
251+ self .mu = self .median = self .mode = aet .zeros (shape )
252+ self .sigma = aet .as_tensor_variable (sigma )
253+
254+ if active_dims is None and active_axes is None :
255+ if shape :
256+ active_axes = (- 1 ,)
257+ else :
258+ active_axes = ()
259+
260+ if isinstance (active_axes , int ):
261+ active_axes = (active_axes ,)
262+
263+ if isinstance (active_dims , str ):
264+ active_dims = (active_dims ,)
265+
266+ if active_axes is not None and active_dims is not None :
267+ raise ValueError ("Only one of active_axes and active_dims can be specified." )
268+
269+ if active_dims is not None :
270+ model = pm .modelcontext (None )
271+ print (model .RV_dims )
272+ if dims is None :
273+ raise ValueError ("active_dims can only be used with the dims kwargs." )
274+ active_axes = []
275+ for dim in active_dims :
276+ active_axes .append (dims .index (dim ))
277+
278+ super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
279+
280+ def logp (self , x ):
281+ return pm .Normal .dist (sigma = self .sigma ).logp (x )
282+
283+ @staticmethod
284+ def _random (scale , size ):
285+ samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
286+ return samples - np .mean (samples , axis = - 1 , keepdims = True )
287+
288+ def random (self , point = None , size = None ):
289+ sigma , = draw_values ([self .sigma ], point = point , size = size )
290+ return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
291+
292+ def _distr_parameters_for_repr (self ):
293+ return ["sigma" ]
294+
295+ def logcdf (self , value ):
296+ raise NotImplementedError ()
297+ > >> >> >> cb0c201 (latest ZeroSumNormal code , pymc3 v3 , random seed for sampling )
0 commit comments