
Conversation

@penelopeysm (Member) commented Nov 8, 2025

This does lead to some improvement in performance, but not as much as I had hoped:

using Turing, Random, LinearAlgebra
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
model = eight_schools(y, sigma)

chn = sample(Xoshiro(468), model, NUTS(), 1000; progress=false)
@time returned(model, chn);
# main   :  0.073007 seconds (394.28 k allocations: 17.673 MiB)
# This PR:  0.068737 seconds (319.98 k allocations: 9.329 MiB)

It appears to me that most of this time is spent faffing with MCMCChains. Every time you try to get the value of `@varname(mu)` you have to go through the `varname_to_symbol` dict, etc. Even more importantly, there's an issue with `theta`: because it's vector-valued, every access requires reconstructing the vector (with `getvalue(dict, vn, dist)`). So I believe we are hitting a natural plateau that is caused by the data structure.
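To make that cost concrete, here is a minimal, self-contained toy sketch (not MCMCChains' actual internals) of why per-access reconstruction of a vector-valued variable is expensive: when storage is flat and keyed per element, every lookup of `theta` has to build string-derived keys and allocate a fresh vector.

```julia
# Illustrative only: a toy flat store keyed by Symbol, mimicking chain-style
# storage where a vector variable is split into theta[1], theta[2], ...
flat = Dict(
    Symbol("theta[1]") => 0.1,
    Symbol("theta[2]") => -0.4,
    Symbol("theta[3]") => 2.3,
)

# Reconstructing the vector allocates on every single access.
function get_theta(flat, n)
    return [flat[Symbol("theta[", i, "]")] for i in 1:n]
end

theta = get_theta(flat, 3)  # a fresh Vector{Float64} per call
```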

Still, I suppose it's worth putting this in because it's basically free performance, so why not?

For FlexiChains, see penelopeysm/FlexiChains.jl#85, which does similar performance tricks but isn't limited by the data structure, and gets a 2x speedup:

using FlexiChains
chn = sample(Xoshiro(468), model, NUTS(), 1000; progress=false, chain_type=VNChain)
@time returned(model, chn);
# full varinfo         : 0.016095 seconds (169.83 k allocations: 13.120 MiB, 11.21% gc time)
# only accs in varinfo : 0.008083 seconds (133.54 k allocations: 5.948 MiB)

A complete aside

Wouldn't it be fun if we could inspect the model, realise that the return value only involves mu and tau, realise that tau is a bits type and thus tau^2 * I cannot mutate tau, and thus optimise away all the things to do with theta and y?

github-actions bot commented Nov 8, 2025

Benchmark Report for Commit 1d79c7f

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.2 │             1.9 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          710.4 │            45.5 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          416.3 │            59.2 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          779.7 │            38.4 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         6922.9 │            25.4 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │          776.6 │            40.5 │
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │          784.0 │            37.0 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          896.3 │            45.6 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          715.7 │             5.9 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          892.2 │             3.9 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         3852.4 │             5.8 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │         1002.1 │             8.9 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        42247.0 │             5.3 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         8799.3 │             9.8 │
│               Dynamic │    10 │    mooncake │             typed │   true │          127.9 │            11.0 │
│              Submodel │     1 │    mooncake │             typed │   true │            8.6 │             6.6 │
│                   LDA │    12 │ reversediff │             typed │   true │          980.9 │             2.1 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

codecov bot commented Nov 8, 2025

Codecov Report

❌ Patch coverage is 88.88889% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.39%. Comparing base (699aa23) to head (1d79c7f).

Files with missing lines Patch % Lines
src/fasteval.jl 50.00% 3 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff               @@
##           py/fastinit    #1130      +/-   ##
===============================================
- Coverage        81.72%   81.39%   -0.34%     
===============================================
  Files               41       41              
  Lines             3956     3928      -28     
===============================================
- Hits              3233     3197      -36     
- Misses             723      731       +8     


@penelopeysm mentioned this pull request Nov 8, 2025
Comment on lines +121 to +160
"""
reevaluate_with(
rng::AbstractRNG,
model::Model,
chain::MCMCChains.Chains;
fallback=nothing,
)
Re-evaluate `model` for each sample in `chain`, returning an matrix of (retval, varinfo)
tuples.
This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
initialisation strategy when re-evaluating the model. For many usecases the fallback should
not be provided (as we expect the chain to contain all necessary variables); but for
`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
the posterior predictions).
"""
function reevaluate_with_chain(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
) where {N}
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
return map(params_with_stats) do ps
varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs))
DynamicPPL.init!!(
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback)
)
end
end
function reevaluate_with_chain(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
) where {N}
return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback)
end
@penelopeysm (Member, Author) commented Nov 8, 2025

Fundamentally, all the functions in this extension really just use this under the hood.

FWIW the FlexiChains extension has a very similar structure and I believe these can be unified pretty much immediately after this PR. To be specific, DynamicPPL.reevaluate_with_chain should be implemented by each chain type in the most performant manner (FlexiChains doesn't use InitFromParams), but the definitions of returned, logjoint, ..., pointwise_logdensities, ... can be shared.

predict can't yet be shared unfortunately, because the include_all keyword argument forces custom MCMCChains / FlexiChains code. That would require an extension of the AbstractChains API to support a subset-like operation.
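As a sketch of what that shared layer could look like, here is a hypothetical helper (not code from this PR; the name `returned_shared` and the empty accumulator tuple are assumptions for illustration) that implements a `returned`-style operation purely in terms of `reevaluate_with_chain`, so any chain type providing its own performant method would get it for free:

```julia
# Hypothetical sketch: a chain-type-agnostic `returned`, assuming each chain
# type implements `DynamicPPL.reevaluate_with_chain` to yield
# (retval, varinfo) pairs in the most performant way for that type.
function returned_shared(model::DynamicPPL.Model, chain)
    # No accumulators are needed just to collect model return values.
    results = DynamicPPL.reevaluate_with_chain(model, chain, ())
    # Keep only the model's return value from each (retval, varinfo) pair.
    return map(first, results)
end
```

The same pattern would apply to `logjoint` and `pointwise_logdensities`, each differing only in which accumulators are passed and which part of the result is kept.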

github-actions bot commented Nov 8, 2025

DynamicPPL.jl documentation for PR #1130 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1130/

Base automatically changed from py/fastinit to py/fastldf November 10, 2025 19:28