1+ import pymc3 as pm
2+ import numpy as np
3+ import pandas as pd
4+ from typing import *
5+ import aesara
6+ import aesara .tensor as aet
7+
8+
9+ def ZeroSumNormal (
10+ name : str ,
11+ sigma : Optional [float ] = None ,
12+ * ,
13+ dims : Union [str , Tuple [str ]],
14+ model : Optional [pm .Model ] = None ,
15+ ):
16+ """
17+ Multivariate normal, such that sum(x, axis=-1) = 0.
18+
19+ Parameters
20+ ----------
21+ name: str
22+ String name representation of the PyMC variable.
23+ sigma: Optional[float], defaults to None
24+ Scale for the Normal distribution. If ``None``, a standard Normal is used.
25+ dims: Union[str, Tuple[str]]
26+ Dimension names for the shape of the distribution.
27+ See https://docs.pymc.io/pymc-examples/examples/pymc3_howto/data_container.html for an example.
28+ model: Optional[pm.Model], defaults to None
29+ PyMC model instance. If ``None``, a model instance is created.
30+ """
31+ if isinstance (dims , str ):
32+ dims = (dims ,)
33+
34+ model = pm .modelcontext (model )
35+ * dims_pre , dim = dims
36+ dim_trunc = f"{ dim } _truncated_"
37+ (shape ,) = model .shape_from_dims ((dim ,))
38+ assert shape >= 1
39+
40+ model .add_coords ({f"{ dim } _truncated_" : pd .RangeIndex (shape - 1 )})
41+ raw = pm .Normal (f"{ name } _truncated_" , dims = tuple (dims_pre ) + (dim_trunc ,), sigma = sigma )
42+ Q = make_sum_zero_hh (shape )
43+ draws = aet .dot (raw , Q [:, 1 :].T )
44+
45+ #if sigma is not None:
46+ # draws = sigma * draws
47+
48+ return pm .Deterministic (name , draws , dims = dims )
49+
50+
51+
52+ def make_sum_zero_hh (N : int ) -> np .ndarray :
53+ """
54+ Build a householder transformation matrix that maps e_1 to a vector of all 1s.
55+ """
56+ e_1 = np .zeros (N )
57+ e_1 [0 ] = 1
58+ a = np .ones (N )
59+ a /= np .sqrt (a @ a )
60+ v = a + e_1
61+ v /= np .sqrt (v @ v )
62+ return np .eye (N ) - 2 * np .outer (v , v )
63+
64+ def make_sum_zero_hh (N : int ) -> np .ndarray :
65+ """
66+ Build a householder transformation matrix that maps e_1 to a vector of all 1s.
67+ """
68+ e_1 = np .zeros (N )
69+ e_1 [0 ] = 1
70+ a = np .ones (N )
71+ a /= np .sqrt (a @ a )
72+ v = a + e_1
73+ v /= np .sqrt (v @ v )
74+ return np .eye (N ) - 2 * np .outer (v , v )
0 commit comments