Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 81 additions & 104 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,19 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
using MCMCChains: MCMCChains

# Return `true` if the `info` NamedTuple carries a `varname_to_symbol` entry,
# i.e. if the owning `Chains` object can be indexed by `VarName`.
function _has_varname_to_symbol(info::NamedTuple{names}) where {names}
    return :varname_to_symbol in names
end

# A `Chains` object supports `VarName` indexing exactly when its `info`
# NamedTuple provides the `varname_to_symbol` mapping.
function DynamicPPL.supports_varname_indexing(c::MCMCChains.Chains)
    return _has_varname_to_symbol(c.info)
end

# Guard helper: throw an informative error unless `c` supports `VarName` indexing.
# Returns `true` when the check passes.
function _check_varname_indexing(c::MCMCChains.Chains)
    DynamicPPL.supports_varname_indexing(c) && return true
    return error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
function getindex_varname(
c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
)
_check_varname_indexing(c)
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
end
function DynamicPPL.varnames(c::MCMCChains.Chains)
_check_varname_indexing(c)
# Return an iterator over the `VarName`s stored in the chain's
# `varname_to_symbol` mapping; error if the chain does not carry one.
function get_varnames(c::MCMCChains.Chains)
    if !haskey(c.info, :varname_to_symbol)
        error("This `Chains` object does not support indexing using `VarName`s.")
    end
    return keys(c.info.varname_to_symbol)
end

# Collect the single sample at (`sample_idx`, `chain_idx`) into a Dict keyed by
# `VarName`, with values of the chain's element type `Tval`.
function chain_sample_to_varname_dict(
    c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
) where {Tval}
    _check_varname_indexing(c)
    return Dict{DynamicPPL.VarName,Tval}(
        vn => DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) for
        vn in DynamicPPL.varnames(c)
    )
end

"""
AbstractMCMC.from_samples(
::Type{MCMCChains.Chains},
Expand Down Expand Up @@ -118,8 +96,8 @@ function AbstractMCMC.to_samples(
# Get parameters
params_matrix = map(idxs) do (sample_idx, chain_idx)
d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(chain)
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
for vn in get_varnames(chain)
d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
end
d
end
Expand All @@ -140,6 +118,47 @@ function AbstractMCMC.to_samples(
end
end

"""
    reevaluate_with_chain(
        [rng::Random.AbstractRNG,]
        model::DynamicPPL.Model,
        chain::MCMCChains.Chains,
        accs::NTuple{N,DynamicPPL.AbstractAccumulator},
        fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
    )

Re-evaluate `model` for each sample in `chain`, returning a matrix of `(retval, varinfo)`
tuples.

This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
initialisation strategy when re-evaluating the model. For many use cases the `fallback`
should not be provided (as we expect the chain to contain all necessary variables); but for
`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
the posterior predictions).
"""
function reevaluate_with_chain(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains,
    accs::NTuple{N,DynamicPPL.AbstractAccumulator},
    fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
) where {N}
    # Extract parameters (and stats) for every (sample, chain) entry up front.
    params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
    return map(params_with_stats) do ps
        # A fresh accumulator-only varinfo is built per sample, so accumulator
        # state cannot carry over between evaluations.
        varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs))
        DynamicPPL.init!!(
            rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback)
        )
    end
end
# Convenience method: forwards to the RNG-taking method using the task-local
# default RNG.
function reevaluate_with_chain(
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains,
    accs::NTuple{N,DynamicPPL.AbstractAccumulator},
    fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
) where {N}
    return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback)
end
Comment on lines +121 to +160
Copy link
Member Author

@penelopeysm penelopeysm Nov 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fundamentally, all the functions in this extension really just use this under the hood.

FWIW the FlexiChains extension has a very similar structure and I believe these can be unified pretty much immediately after this PR. To be specific, DynamicPPL.reevaluate_with_chain should be implemented by each chain type in the most performant manner (FlexiChains doesn't use InitFromParams), but the definitions of returned, logjoint, ..., pointwise_logdensities, ... can be shared.

predict can't yet be shared unfortunately, because the include_all keyword argument forces custom MCMCChains / FlexiChains code. That would require an extension of the AbstractChains API to support a subset-like operation.


"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -208,30 +227,18 @@ function DynamicPPL.predict(
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)

# Set up a VarInfo with the right accumulators
varinfo = DynamicPPL.setaccs!!(
DynamicPPL.VarInfo(),
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(false),
),
accs = (
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(false),
)
_, varinfo = DynamicPPL.init!!(model, varinfo)
varinfo = DynamicPPL.typed_varinfo(varinfo)

params_and_stats = AbstractMCMC.to_samples(
DynamicPPL.ParamsWithStats, parameter_only_chain
predictions = map(
DynamicPPL.ParamsWithStats ∘ last,
reevaluate_with_chain(
rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior()
),
)
predictions = map(params_and_stats) do ps
_, varinfo = DynamicPPL.init!!(
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
)
DynamicPPL.ParamsWithStats(varinfo)
end
chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions)

parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)
else
Expand Down Expand Up @@ -311,18 +318,7 @@ julia> returned(model, chain)
"""
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
return map(params_with_stats) do ps
first(
DynamicPPL.init!!(
model,
varinfo,
DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
),
)
end
return map(first, reevaluate_with_chain(model, chain, (), nothing))
end

"""
Expand Down Expand Up @@ -415,24 +411,13 @@ function DynamicPPL.pointwise_logdensities(
::Type{Tout}=MCMCChains.Chains,
::Val{whichlogprob}=Val(:both),
) where {whichlogprob,Tout}
vi = DynamicPPL.VarInfo(model)
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
accname = DynamicPPL.accumulator_name(acc)
vi = DynamicPPL.setaccs!!(vi, (acc,))
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
pointwise_logps = map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Re-evaluate the model
_, vi = DynamicPPL.init!!(
model,
vi,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
DynamicPPL.getacc(vi, Val(accname)).logps
end

pointwise_logps =
map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi)
DynamicPPL.getacc(vi, Val(accname)).logps
end
# pointwise_logps is a matrix of OrderedDicts
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
for d in pointwise_logps
Expand Down Expand Up @@ -519,15 +504,15 @@ julia> logjoint(demo_model([1., 2.]), chain)
```
"""
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
return map(
DynamicPPL.getlogjoint ∘ last,
reevaluate_with_chain(
model,
chain,
(DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
nothing,
),
)
end

"""
Expand Down Expand Up @@ -559,15 +544,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain)
```
"""
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.loglikelihood(model, argvals_dict)
end
return map(
DynamicPPL.getloglikelihood ∘ last,
reevaluate_with_chain(
model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing
),
)
end

"""
Expand Down Expand Up @@ -600,15 +582,10 @@ julia> logprior(demo_model([1., 2.]), chain)
```
"""
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
return map(
DynamicPPL.getlogprior ∘ last,
reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing),
)
end

end
26 changes: 0 additions & 26 deletions src/chains.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,3 @@
"""
supports_varname_indexing(chain::AbstractChains)

Return `true` if `chain` supports indexing using `VarName` in place of the
variable name index.
"""
supports_varname_indexing(::AbstractChains) = false

"""
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)

Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.

Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function getindex_varname end

"""
varnames(chains::AbstractChains)

Return an iterator over the varnames present in `chains`.

Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function varnames end

"""
ParamsWithStats

Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
function InitFromParams(
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
)
return InitFromParams(to_varname_dict(params), fallback)
return new{typeof(params),typeof(fallback)}(params, fallback)
end
end
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
Expand Down
24 changes: 17 additions & 7 deletions src/fasteval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ using DynamicPPL:
AbstractContext,
AbstractVarInfo,
AccumulatorTuple,
DynamicPPL,
InitContext,
InitFromParams,
InitFromPrior,
InitFromUniform,
LogJacobianAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
Expand All @@ -81,6 +86,7 @@ using DynamicPPL:
getlogprior_internal,
leafcontext
using ADTypes: ADTypes
using BangBang: BangBang
using Bijectors: with_logabsdet_jacobian
using AbstractPPL: AbstractPPL, VarName
using Distributions: Distribution
Expand Down Expand Up @@ -108,6 +114,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
# OnlyAccsVarInfo holds only accumulators (see `getaccs`/`setaccs!!` below and
# the no-op `push!!`), so there are no stored values to invlink before eval.
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
# No variable values are stored, so no `VarName` is ever present ...
@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false
# ... and nothing is ever in linked (transformed) space.
@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false
# Pushing a new variable is a no-op: the value is discarded, only accumulators
# are retained.
@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi
function DynamicPPL.get_param_eltype(
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
)
Expand All @@ -117,19 +126,20 @@ function DynamicPPL.get_param_eltype(
leaf_ctx = DynamicPPL.leafcontext(model.context)
if leaf_ctx isa FastEvalVectorContext
return eltype(leaf_ctx.params)
elseif leaf_ctx isa InitContext
return _get_strategy_eltype(leaf_ctx.strategy)
else
# TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}.
# See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
# figure out the parameter type from a NamedTuple or Dict. The benefit of
# implementing this for InitContext is that we could then use OnlyAccsVarInfo with
# it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
# that Mooncake / Enzyme should be able to differentiate through that too and
# provide a NamedTuple of gradients (although I haven't tested this yet).
error(
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
)
end
end
# Infer the parameter eltype implied by an init strategy; called from
# `get_param_eltype` when the leaf context is an `InitContext`.
_get_strategy_eltype(s::InitFromParams) = DynamicPPL.infer_nested_eltype(typeof(s.params))
# No need to enforce any particular eltype here, since new parameters are sampled
_get_strategy_eltype(::InitFromPrior) = Any
_get_strategy_eltype(::InitFromUniform) = Any
# Default fallback
_get_strategy_eltype(::DynamicPPL.AbstractInitStrategy) = Any

"""
RangeAndLinked
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,7 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
```
"""
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
vi = DynamicPPL.setaccs!!(VarInfo(), ())
vi = DynamicPPL.Experimental.OnlyAccsVarInfo(AccumulatorTuple())
# Note: we can't use `fix(model, parameters)` because
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
Expand Down