Skip to content

Commit 5cc74d1

Browse files
johncantjessegrabowski
authored andcommitted
Port TF bijector to ensure posdef LKJCorr samples
1 parent 9f653a6 commit 5cc74d1

File tree

2 files changed

+118
-1
lines changed

2 files changed

+118
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1647,7 +1647,9 @@ def logp(value, n, eta):
16471647

16481648
@_default_transform.register(_LKJCorr)
16491649
def lkjcorr_default_transform(op, rv):
1650-
return MultivariateIntervalTransform(-1.0, 1.0)
1650+
_, _, _, n, *_ = rv.owner.inputs
1651+
n = n.eval()
1652+
return transforms.CholeskyCorr(n)
16511653

16521654

16531655
class LKJCorr:

pymc/distributions/transforms.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import pytensor.tensor as pt
19+
import pytensor
1920

2021
from pytensor.graph import Op
2122
from pytensor.npy_2_compat import normalize_axis_tuple
@@ -44,6 +45,11 @@
4445
"ordered",
4546
"simplex",
4647
"sum_to_1",
48+
"circular",
49+
"CholeskyCorr",
50+
"CholeskyCovPacked",
51+
"Chain",
52+
"ZeroSumTransform",
4753
]
4854

4955

@@ -135,6 +141,115 @@ def log_jac_det(self, value, *inputs):
135141
return pt.sum(y, axis=-1)
136142

137143

144+
class CholeskyCorr(Transform):
145+
"""
146+
Transforms the off-diagonal elements of a correlation matrix to
147+
unconstrained real numbers.
148+
149+
Note: This is not particular to the LKJ distribution - it is only a
150+
transform to help generate cholesky decompositions for random valid
151+
correlation matrices.
152+
153+
Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31
154+
155+
The backward side of this transformation is the off-diagonal upper
156+
triangular elements of a correlation matrix, specified in row major order.
157+
"""
158+
159+
name = "cholesky-corr"
160+
161+
def __init__(self, n):
162+
"""
163+
164+
Parameters
165+
----------
166+
n: int
167+
Size of correlation matrix
168+
"""
169+
self.n = n
170+
self.m = int(n*(n-1)/2) # number of off-diagonal elements
171+
self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
172+
self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
173+
174+
def _generate_tril_indices(self):
175+
row_indices, col_indices = np.tril_indices(self.n, -1)
176+
return (
177+
pytensor.shared(row_indices),
178+
pytensor.shared(col_indices)
179+
)
180+
181+
def _generate_triu_indices(self):
182+
row_indices, col_indices = np.triu_indices(self.n, 1)
183+
return (
184+
pytensor.shared(row_indices),
185+
pytensor.shared(col_indices)
186+
)
187+
188+
def _jacobian(self, value, *inputs):
189+
return pt.jacobian(
190+
self.backward(value),
191+
wrt=value
192+
)
193+
194+
def log_jac_det(self, value, *inputs):
195+
"""
196+
Compute log of the determinant of the jacobian.
197+
198+
There are no clever tricks here - we literally compute the jacobian
199+
then compute its determinant then take log.
200+
"""
201+
jac = self._jacobian(value)
202+
return pt.log(pt.linalg.det(jac))
203+
204+
def forward(self, value, *inputs):
205+
"""
206+
Convert the off-diagonal elements of a cholesky decomposition of a
207+
correlation matrix to unconstrained real numbers.
208+
"""
209+
# The correlation matrix is specified via its upper triangular elements
210+
corr = pt.set_subtensor(
211+
pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs],
212+
value
213+
)
214+
corr = corr + corr.T + pt.eye(self.n)
215+
216+
chol = pt.linalg.cholesky(corr)
217+
218+
# Are the diagonals always guaranteed to be positive?
219+
# I don't know, so we'll use abs
220+
row_norms = 1/pt.abs(pt.diag(chol))
221+
222+
# Multiply by the row norms to undo the normalization
223+
unconstrained = chol*row_norms[:, pt.newaxis]
224+
225+
return unconstrained[self.tril_r_idxs, self.tril_c_idxs]
226+
227+
def backward(self, value, *inputs, foo=False):
228+
"""
229+
Convert unconstrained real numbers to the off-diagonal elements of the
230+
cholesky decomposition of a correlation matrix.
231+
"""
232+
# The diagonals of this matrix are 1, but these ones are just used for
233+
# computing a denominator. The diagonals of the cholesky factor are not
234+
# returned, but they are not ones.
235+
chol_pre_norm = pt.set_subtensor(
236+
pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs],
237+
value
238+
)
239+
240+
# derivative of pt.linalg.norm ended up complex, which caused errors
241+
# row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX")
242+
243+
row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5)
244+
chol = chol_pre_norm / row_norm[:, pt.newaxis]
245+
246+
# Undo the cholesky decomposition
247+
corr = pt.matmul(chol, chol.T)
248+
249+
# We want the upper triangular indices here.
250+
return corr[self.triu_r_idxs, self.triu_c_idxs]
251+
252+
138253
class CholeskyCovPacked(Transform):
139254
"""Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale."""
140255

0 commit comments

Comments
 (0)