|
16 | 16 |
|
17 | 17 | import numpy as np |
18 | 18 | import pytensor.tensor as pt |
| 19 | +import pytensor |
19 | 20 |
|
20 | 21 | from pytensor.graph import Op |
21 | 22 | from pytensor.npy_2_compat import normalize_axis_tuple |
|
44 | 45 | "ordered", |
45 | 46 | "simplex", |
46 | 47 | "sum_to_1", |
| 48 | + "circular", |
| 49 | + "CholeskyCorr", |
| 50 | + "CholeskyCovPacked", |
| 51 | + "Chain", |
| 52 | + "ZeroSumTransform", |
47 | 53 | ] |
48 | 54 |
|
49 | 55 |
|
@@ -135,6 +141,115 @@ def log_jac_det(self, value, *inputs): |
135 | 141 | return pt.sum(y, axis=-1) |
136 | 142 |
|
137 | 143 |
|
class CholeskyCorr(Transform):
    """Map the off-diagonal entries of a correlation matrix to unconstrained reals.

    The constrained side of the transform is the vector of strictly
    upper-triangular entries of an ``n x n`` correlation matrix, listed in
    row-major order. The unconstrained side is a vector of the same length
    holding the strictly lower-triangular entries of the matrix's Cholesky
    factor after every row has been divided by its diagonal entry.

    Note: this is not particular to the LKJ distribution - it is only a
    transform that helps generate Cholesky factors of valid correlation
    matrices.

    Ported from:
    https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31

    NOTE(review): the indexing below assumes ``value`` is a single vector of
    length ``n * (n - 1) / 2`` (no batch dimensions) - confirm with callers.
    """

    name = "cholesky-corr"

    def __init__(self, n):
        """Pre-compute the triangular index sets used by the transform.

        Parameters
        ----------
        n: int
            Size of correlation matrix
        """
        self.n = n
        # Number of off-diagonal elements on one side of the diagonal.
        # Integer floor division keeps the count exact (n * (n - 1) is even).
        self.m = int(n * (n - 1) // 2)
        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()

    def _generate_tril_indices(self):
        """Return shared (row, column) indices of the strictly lower triangle."""
        row_indices, col_indices = np.tril_indices(self.n, -1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices),
        )

    def _generate_triu_indices(self):
        """Return shared (row, column) indices of the strictly upper triangle."""
        row_indices, col_indices = np.triu_indices(self.n, 1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices),
        )

    def _jacobian(self, value, *inputs):
        """Return the Jacobian of ``backward`` with respect to ``value``."""
        return pt.jacobian(self.backward(value, *inputs), wrt=value)

    def log_jac_det(self, value, *inputs):
        """Compute the log of the absolute Jacobian determinant of ``backward``.

        There are no clever tricks here - the full Jacobian is built, its
        determinant is taken, and the log of its absolute value is returned.
        The absolute value matters: the sign of the determinant depends on the
        ordering of the inputs and outputs, and the log of a negative
        determinant would be NaN.
        """
        jac = self._jacobian(value, *inputs)
        return pt.log(pt.abs(pt.linalg.det(jac)))

    def forward(self, value, *inputs):
        """Convert correlation-matrix entries to unconstrained real numbers.

        Parameters
        ----------
        value
            Strictly upper-triangular entries of a correlation matrix, in
            row-major order.

        Returns
        -------
        The strictly lower-triangular entries of the Cholesky factor of that
        matrix, after each row has been divided by its diagonal entry.
        """
        # Rebuild the full symmetric matrix from its upper triangle; the
        # diagonal of a correlation matrix is all ones.
        upper = pt.set_subtensor(
            pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs],
            value,
        )
        corr = upper + upper.T + pt.eye(self.n)

        chol = pt.linalg.cholesky(corr)

        # `backward` builds each row with a unit diagonal and then divides the
        # row by its norm, so dividing each row of the factor by its diagonal
        # entry undoes that normalization. The diagonal of the Cholesky factor
        # of a positive-definite matrix is strictly positive, so no absolute
        # value is needed.
        inv_diag = 1 / pt.diag(chol)
        unconstrained = chol * inv_diag[:, pt.newaxis]

        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

    def backward(self, value, *inputs, foo=False):
        """Convert unconstrained real numbers to correlation-matrix entries.

        Parameters
        ----------
        value
            Unconstrained vector that fills the strictly lower triangle of the
            un-normalized Cholesky factor.
        foo
            Unused; accepted only so that existing call sites keep working.

        Returns
        -------
        The strictly upper-triangular entries of the resulting correlation
        matrix, in row-major order.
        """
        # The ones on the diagonal only seed the row norms computed below; the
        # diagonal of the normalized Cholesky factor is generally not one.
        chol_pre_norm = pt.set_subtensor(
            pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs],
            value,
        )

        # The row norm is written out by hand because the derivative of
        # pt.linalg.norm came out complex-valued, which caused errors when
        # building the Jacobian. A sum of squares is never negative (and here
        # is at least one because of the unit diagonal), so it can be passed
        # to the square root directly.
        row_norm = pt.sqrt((chol_pre_norm**2).sum(axis=1))
        chol = chol_pre_norm / row_norm[:, pt.newaxis]

        # Undo the Cholesky decomposition; unit-norm rows give a unit diagonal.
        corr = pt.matmul(chol, chol.T)

        # Only the strictly upper-triangular entries are returned.
        return corr[self.triu_r_idxs, self.triu_c_idxs]
| 251 | + |
| 252 | + |
138 | 253 | class CholeskyCovPacked(Transform): |
139 | 254 | """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.""" |
140 | 255 |
|
|
0 commit comments