Implement predict, returned, logjoint, ... with fast eval
#1130
base: py/fastldf
Conversation
Benchmark Report for Commit 1d79c7f (Computer Information and Benchmark Results collapsed in original)
Codecov Report
❌ Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##           py/fastinit    #1130      +/- ##
===============================================
- Coverage        81.72%   81.39%    -0.34%
===============================================
  Files               41       41
  Lines             3956     3928       -28
===============================================
- Hits              3233     3197       -36
- Misses             723      731        +8

☔ View full report in Codecov by Sentry.
| """ | ||
| reevaluate_with( | ||
| rng::AbstractRNG, | ||
| model::Model, | ||
| chain::MCMCChains.Chains; | ||
| fallback=nothing, | ||
| ) | ||
| Re-evaluate `model` for each sample in `chain`, returning an matrix of (retval, varinfo) | ||
| tuples. | ||
| This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the | ||
| initialisation strategy when re-evaluating the model. For many usecases the fallback should | ||
| not be provided (as we expect the chain to contain all necessary variables); but for | ||
| `predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating | ||
| the posterior predictions). | ||
| """ | ||
| function reevaluate_with_chain( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| chain::MCMCChains.Chains, | ||
| accs::NTuple{N,DynamicPPL.AbstractAccumulator}, | ||
| fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, | ||
| ) where {N} | ||
| params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) | ||
| return map(params_with_stats) do ps | ||
| varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) | ||
| DynamicPPL.init!!( | ||
| rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback) | ||
| ) | ||
| end | ||
| end | ||
| function reevaluate_with_chain( | ||
| model::DynamicPPL.Model, | ||
| chain::MCMCChains.Chains, | ||
| accs::NTuple{N,DynamicPPL.AbstractAccumulator}, | ||
| fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, | ||
| ) where {N} | ||
| return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) | ||
| end |
Fundamentally, all the functions in this extension really just use this under the hood.
FWIW the FlexiChains extension has a very similar structure and I believe these can be unified pretty much immediately after this PR. To be specific, DynamicPPL.reevaluate_with_chain should be implemented by each chain type in the most performant manner (FlexiChains doesn't use InitFromParams), but the definitions of returned, logjoint, ..., pointwise_logdensities, ... can be shared.
predict can't yet be shared unfortunately, because the include_all keyword argument forces custom MCMCChains / FlexiChains code. That would require an extension of the AbstractChains API to support a subset-like operation.
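As a concrete illustration of the unification described above, a shared `returned` could be layered on top of the chain-specific primitive roughly as follows. This is a sketch, not code from this PR: it assumes `reevaluate_with_chain(model, chain, accs)` is the one method each chain type provides, and the helper name `shared_returned` is hypothetical.

```julia
# Hypothetical sketch: `returned` written purely in terms of the
# chain-specific primitive `reevaluate_with_chain`, so that MCMCChains and
# FlexiChains would only need to implement that one method each.
function shared_returned(model, chain)
    # No accumulators are needed just to recover the model's return value.
    results = reevaluate_with_chain(model, chain, ())
    # `reevaluate_with_chain` yields (retval, varinfo) tuples; keep the retvals.
    return map(first, results)
end
```

The same shape would presumably work for `logjoint` and friends by passing the appropriate accumulators and reading them back out of each returned varinfo.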
DynamicPPL.jl documentation for PR #1130 is available at:
Force-pushed from 699aa23 to c35dff5
This does lead to some improvement in performance, but not as much as I had hoped.

It appears to me that most of this time is spent faffing with MCMCChains. Every time you try to get the value of `@varname(mu)` you have to go through the `varname_to_symbol` dict, etc. Even more importantly, there's an issue with `theta`: because that's vector-valued, when you access it you have to reconstruct the vector (with `getvalue(dict, vn, dist)`). So I believe we are hitting a natural plateau that is caused by the data structure.

Still, I suppose it's worth putting this in because it's basically free performance, so why not?

For FlexiChains, see penelopeysm/FlexiChains.jl#85, which does similar performance tricks but isn't limited by the data structure, and gets a 2x speedup.
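To make the reconstruction cost above concrete, here is a self-contained toy illustration (plain Julia, not DynamicPPL code): when a vector-valued parameter is stored as flattened per-index entries, every access has to rebuild the whole vector, allocating each time.

```julia
# Toy stand-in for a flattened chain storage: each scalar component of the
# vector-valued `theta` lives under its own symbol.
flat = Dict(
    Symbol("theta[1]") => 0.1,
    Symbol("theta[2]") => 0.2,
    Symbol("theta[3]") => 0.3,
)

# Every access must reassemble the vector from its components, allocating a
# fresh array each time -- this is the per-sample overhead described above.
reconstruct_theta(d, n) = [d[Symbol("theta[", i, "]")] for i in 1:n]

theta = reconstruct_theta(flat, 3)  # allocates on every call
```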
A complete aside
Wouldn't it be fun if we could inspect the model, realise that the return value only involves `mu` and `tau`, realise that `tau` is a bits type and thus `tau^2 * I` cannot mutate `tau`, and thus optimise away all the things to do with `theta` and `y`?
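For reference, a hypothetical model of the shape the aside is gesturing at might look like this (Turing.jl-style syntax, assumed for illustration; this model is not taken from the PR). The return value touches only `mu` and `tau`, so in principle everything involving `theta` and `y` is dead code when computing `returned`.

```julia
using Turing, LinearAlgebra

# Hypothetical model: the return value depends only on `mu` and `tau`, and
# since `tau` is a bits type, `tau^2 * I` cannot mutate it -- so the `theta`
# and `y` statements could in principle be skipped entirely for `returned`.
@model function demo(y)
    mu ~ Normal(0, 1)
    tau ~ truncated(Normal(0, 1); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    y ~ MvNormal(theta, I)
    return (mu=mu, tau=tau)
end
```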