Implement unconstraining transform for LKJCorr #7380
Conversation
Hi, it's unlikely I'm going to have any time to work on this for the next 6 months. The hardest part is coming up with a closed-form solution for `log_det_jac`, which I don't think I'm very close to doing.
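A common way to gain confidence in a closed-form `log_det_jac` is to check it against a finite-difference derivative. The sketch below does this for a simple (-1, 1) → R interval map, purely for illustration; it is not the LKJCorr transform discussed in this thread.

```python
import numpy as np

# Hedged sketch: sanity-check a closed-form log-det-Jacobian against a
# finite-difference derivative. The transform here is an illustrative
# interval map, not the PR's LKJCorr transform.

def forward(x):
    # y = log((1 + x) / (1 - x)), maps (-1, 1) to the real line
    return np.log1p(x) - np.log1p(-x)

def log_det_jac(x):
    # d/dx forward(x) = 2 / (1 - x**2), so log|J| = log(2) - log(1 - x**2)
    return np.log(2.0) - np.log1p(-(x**2))

x = 0.3
h = 1e-6
numeric = (forward(x + h) - forward(x - h)) / (2 * h)
assert np.isclose(np.log(numeric), log_det_jac(x), atol=1e-6)
```

This kind of numerical check is also what the test suite can fall back on when no closed form is available.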
Thanks for the update @johncant, and for pushing this as far as you did.
Force-pushed df723bc to cf8d9a8
This is updated to a fully working implementation of the transformer, with non-trivial tests. It is currently blocked by #7910, because it uses

There's some jank -- I'm not sure we need to pass

If we got rid of the

On that note, those two triangular pack helpers could themselves be a separate transformation, since they define a bijection with forward and backward methods. That might be conceptually better, but on the other hand I don't think these would be used anywhere else, so it also makes sense to keep them as part of this class.
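For readers unfamiliar with the idea, a pair of triangular pack helpers forming a bijection could look like the numpy sketch below. The names and plain row-major ordering are assumptions for illustration, not the PR's actual helpers (which follow tfp's spiral ordering).

```python
import numpy as np

# Hedged sketch of triangular pack helpers as a forward/backward bijection.
# Names and ordering are hypothetical, not the PR's API.

def pack_strict_lower(x):
    """Forward: flatten the strictly lower triangle of an (n, n) matrix."""
    n = x.shape[-1]
    rows, cols = np.tril_indices(n, k=-1)
    return x[..., rows, cols]

def unpack_strict_lower(v, n):
    """Backward: rebuild a symmetric unit-diagonal matrix from the vector."""
    out = np.eye(n)
    rows, cols = np.tril_indices(n, k=-1)
    out[rows, cols] = v
    out[cols, rows] = v
    return out

corr = np.array([[1.0, 0.2, -0.1],
                 [0.2, 1.0, 0.5],
                 [-0.1, 0.5, 1.0]])
packed = pack_strict_lower(corr)  # vector of the n*(n-1)/2 free entries
assert np.allclose(unpack_strict_lower(packed, 3), corr)
```

Because the pair is invertible, it meets the minimal contract of a transform's forward/backward methods, which is why it could in principle live as its own transformation.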
Force-pushed 3ac4970 to 1f29284
BTW, once this is merged we should still explore the new transform by that Stan dev; it should have better sampling properties, IIRC.
Do you have a link to the actual implementation? I was rooting around in the stan/stan-math repo and couldn't find it.
Wow, congratulations @jessegrabowski!
You had it 90% of the way there; we were just missing this weird spiral triangular construction that tfp was doing. I have no idea why they do it this way, though. I just blindly copied xD
Force-pushed 1f29284 to a6ab223
I switched the LKJCorr implementation to use
We could potentially try to check

Here is the benchmark I ran:

Before (unrolled loop):

After (scan):
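The shape of a benchmark like this can be sketched with `timeit`: time an unrolled step-by-step loop against a single fused call. The two functions below are stand-ins for illustration, not the actual unrolled/scan LKJCorr graphs.

```python
import timeit
import numpy as np

# Hedged sketch of the kind of micro-benchmark discussed here: an unrolled
# Python-level loop versus one fused call. Stand-in workloads, not the PR's.

x = np.random.default_rng(0).normal(size=1_000)

def unrolled():
    total = 0.0
    for xi in x:          # one explicit step per element, like an unrolled graph
        total += xi * xi
    return total

def fused():
    return float(x @ x)   # single call, like a compiled scan

t_unrolled = timeit.timeit(unrolled, number=100)
t_fused = timeit.timeit(fused, number=100)
assert np.isclose(unrolled(), fused())  # both compute the same quantity
```

The interesting comparison is the ratio `t_unrolled / t_fused` as the problem size grows, since unrolled graphs also pay heavily at compile time.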
We can support estimation of
Those timings look good enough that I wouldn't worry. It will hopefully fall out from our goal of speeding up Scan. Can you try Numba mode, out of curiosity?
Numba timings.

With scan:

Unrolled loop (old implementation):

Pretty bad!
JAX also doesn't work with the new scan implementation, I guess because
Length shouldn't matter if only the last state is used? Can you write the index as a recurring output? I don't recall where `arange` shows up.
Timings are pretty nice; not worried about the small case.

No, it's bad. I updated the original timings with a compiled function instead of
In the compiled version, did you set updates so the rng is not copied?
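The concern about a copied rng can be illustrated with plain numpy generators: a copy replays the same state, while an in-place update advances it. This is a hedged analogy, not PyTensor's actual shared-variable update mechanism.

```python
import copy
import numpy as np

# Hedged illustration of why a copied rng is a problem: if a compiled
# function copies its generator instead of updating it in place, every call
# replays the same draws (and pays the cost of the copy).

rng = np.random.default_rng(42)
copied = copy.deepcopy(rng)

first = rng.normal()      # advances rng's internal state
replay = copied.normal()  # the copy still starts from the old state

assert first == replay          # a copying function repeats its draws
assert rng.normal() != first    # the updated generator moves on
```

In PyTensor terms, this is what passing the rng's update rule to the compiled function avoids.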
I used
Ah, so you were seeing longer compile times that made it look like scan was doing better? Unrolling 100 steps or more is going to be pretty ugly, though.
Can you show the final numba graph? |
`return False`

`class PosDefMatrix(Op):`
Get rid of this Op? It requires Python mode, and it's only calling `cholesky` to see if it raises.
Sure. What we really need anyway is a complete typing system for matrices :)
```python
def log_jac_det(self, *args):
    return super().log_jac_det(*args).sum(-1)
```
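The `.sum(-1)` above reflects a standard fact: for an elementwise transform the Jacobian is diagonal, so the log-determinant of the whole vector transform is the sum of per-element log-derivatives. A hedged numpy sketch, using `arctanh` purely as an example elementwise map:

```python
import numpy as np

# Hedged sketch: summing per-element log-derivatives over the last axis
# equals the log|det| of the (diagonal) Jacobian of an elementwise map.

x = np.array([0.1, -0.5, 0.3])
per_element = np.log(1.0 / (1.0 - x**2))   # log d/dx arctanh(x), per element
total_log_det = per_element.sum(-1)

J = np.diag(1.0 / (1.0 - x**2))            # the diagonal Jacobian itself
assert np.isclose(total_log_det, np.log(np.abs(np.linalg.det(J))))
```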
`P = pt.concatenate([P0[None], P_seq], axis=0)[-1]`
Why do you concatenate and then take the last value? Sometimes `P_seq` is empty?
Yes, when n=2. I had an `ifelse` for that, but it still breaks in non-VM backends, so this was my other idea. The final fallback is to go back to non-symbolic `n`, but that's a bit of a bummer.
Scan over `pt.arange(2, pt.maximum(n, 3))` and then `P = pt.where(n > 2, P_seq[-1], P0)`?
That will still compute the `P_seq[-1]` branch in all cases, though?
Oh, you want to do an extra scan step for safety. Seems about the same: both allocate an extra (n, n) in memory, but yours also does the extra computation.
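The concatenate-then-take-last trick from the diff can be demonstrated with plain numpy: prepending `P0` keeps the expression valid even when the scan produced zero steps (the n == 2 case), with no `ifelse` branch. This is a hedged illustration of the idea, not the PR's PyTensor graph.

```python
import numpy as np

# Hedged numpy sketch of P = concatenate([P0[None], P_seq], axis=0)[-1]:
# when P_seq is empty, the result falls back to P0; otherwise it is the
# last scan step.

P0 = np.eye(2)

P_seq_empty = np.empty((0, 2, 2))          # n == 2: scan ran zero steps
P_seq_full = np.stack([2.0 * np.eye(2)])   # n > 2: at least one step

P_when_empty = np.concatenate([P0[None], P_seq_empty], axis=0)[-1]
P_when_full = np.concatenate([P0[None], P_seq_full], axis=0)[-1]

assert np.allclose(P_when_empty, P0)
assert np.allclose(P_when_full, 2.0 * np.eye(2))
```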
```python
# If a distribution has non-stochastic elements in the output (e.g. LKJCorr putting 1's on the diagonal),
# it will mess up the KS test. So we filter those out here.
stacked_samples = np.c_[np.atleast_1d(s0).flatten(), np.atleast_1d(s1).flatten()]
samples = stacked_samples[~np.isclose(stacked_samples[..., 0], stacked_samples[..., 1])]
```
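Concretely, the filtering above drops entries that are identical across both sample sets (like LKJCorr's unit diagonal) before a KS test, since constant values violate the test's continuity assumption. A hedged sketch with toy stand-in arrays:

```python
import numpy as np

# Hedged sketch of the filtering in the snippet above. s0/s1 are toy
# stand-ins: the 1.0 entries play the role of LKJCorr's deterministic
# diagonal and are dropped.

s0 = np.array([1.0, 0.3, -0.2, 1.0])
s1 = np.array([1.0, 0.1, 0.4, 1.0])

stacked = np.c_[np.atleast_1d(s0).flatten(), np.atleast_1d(s1).flatten()]
kept = stacked[~np.isclose(stacked[..., 0], stacked[..., 1])]

assert kept.shape == (2, 2)   # the two deterministic 1.0 pairs are dropped
```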
I don't like this change. Use some other routine in the test instead?
Sure. I think it's a pretty questionable test in the first place.
```python
# 2) Although n can be symbolic, the inner scan graph needs to be rebuilt after it changes. The approach
#    implemented in this tester does not rebuild the inner function graph, causing an error.
```
The comment doesn't make sense to me. Why/how could the Scan be rebuilt when `n` changes, if `n` can be symbolic?
I don't remember; it was something weird about how these tests work.
```python
def lkjcorr_default_transform(op, rv):
    return MultivariateIntervalTransform(-1.0, 1.0)

rng, scan_rng, shape, n, eta, *_ = rv.owner.inputs
n = pt.get_scalar_constant_value(n)  # Safely extract scalar value without eval
```
If `n` can be arbitrarily symbolic, there is no way to safely extract it :)
```python
# In general, RVs are expected to take an "rng" argument. We allow it here to prevent API break, but
# ignore it. This can be changed in the future if we relax the requirement that rng be a shared variable
# in a scan.
rng = kwargs.pop("rng", None)
if rng is not None:
    warnings.warn(
        "You passed a random generator to LKJCorr via the `rng` keyword argument, but it is not "
        "used. To seed LKJCorr, pass two random generators via the `outer_rng` and `scan_rng` "
        "keyword arguments.",
    )
```
This can be removed once we pin to the next PyTensor release.

I've ported this bijector from `tensorflow` and added it to `LKJCorr`. This ensures that initial samples drawn from `LKJCorr` are positive definite, which fixes #7101. Sampling now completes successfully with no divergences.

There are several parts I'm not comfortable with:
- the `n` parameter from `op` or `rv` without `eval`ing any pytensors?

@fonnesbeck @twiecki @jessegrabowski @velochy - could you please take a look? I would like to make sure that this fix makes sense before adding tests and making the linters pass.

Notes:
- `forward` in `tensorflow_probability` is `backward` in `pymc`

Description
Backward method
Forward method
log_jac_det
This was quite complicated to implement, so I used the symbolic jacobian.
📚 Documentation preview 📚: https://pymc--7380.org.readthedocs.build/en/7380/