diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index d8c343917..9b34a9849 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -1,41 +1,19 @@
 module DynamicPPLMCMCChainsExt
 
-using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
+using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
 using MCMCChains: MCMCChains
 
-_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
-
-function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
-    return _has_varname_to_symbol(chain.info)
-end
-
-function _check_varname_indexing(c::MCMCChains.Chains)
-    return DynamicPPL.supports_varname_indexing(c) ||
-           error("This `Chains` object does not support indexing using `VarName`s.")
-end
-
-function DynamicPPL.getindex_varname(
+function getindex_varname(
     c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
 )
-    _check_varname_indexing(c)
     return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
 end
-function DynamicPPL.varnames(c::MCMCChains.Chains)
-    _check_varname_indexing(c)
+function get_varnames(c::MCMCChains.Chains)
+    haskey(c.info, :varname_to_symbol) ||
+        error("This `Chains` object does not support indexing using `VarName`s.")
     return keys(c.info.varname_to_symbol)
 end
-function chain_sample_to_varname_dict(
-    c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
-) where {Tval}
-    _check_varname_indexing(c)
-    d = Dict{DynamicPPL.VarName,Tval}()
-    for vn in DynamicPPL.varnames(c)
-        d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
-    end
-    return d
-end
-
 """
     AbstractMCMC.from_samples(
         ::Type{MCMCChains.Chains},
@@ -118,8 +96,8 @@ function AbstractMCMC.to_samples(
     # Get parameters
     params_matrix = map(idxs) do (sample_idx, chain_idx)
         d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
-        for vn in DynamicPPL.varnames(chain)
-            d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
+        for vn in get_varnames(chain)
+            d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
         end
         d
     end
@@ -140,6 +118,47 @@ function AbstractMCMC.to_samples(
     end
 end
 
+"""
+    reevaluate_with_chain(
+        [rng::Random.AbstractRNG,]
+        model::DynamicPPL.Model,
+        chain::MCMCChains.Chains,
+        accs::NTuple{N,DynamicPPL.AbstractAccumulator},
+        fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
+    )
+
+Re-evaluate `model` for each sample in `chain`, returning a matrix of `(retval, varinfo)`
+tuples. The varinfo used for each evaluation carries exactly the accumulators passed in
+`accs`.
+
+This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
+initialisation strategy when re-evaluating the model. For many use cases the fallback
+should not be provided (as we expect the chain to contain all necessary variables); but for
+`predict` it has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
+the posterior predictions).
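+
+A rough usage sketch (assuming `model` and `chain` are a compatible model/chain pair; this
+mirrors how `DynamicPPL.logjoint` below uses the helper):
+
+```julia
+accs = (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator())
+# Each entry of `results` is a (retval, varinfo) tuple for one chain sample.
+results = reevaluate_with_chain(model, chain, accs)
+logjoints = map(DynamicPPL.getlogjoint ∘ last, results)
+```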
+""" +function reevaluate_with_chain( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(params_with_stats) do ps + varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) + DynamicPPL.init!!( + rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback) + ) + end +end +function reevaluate_with_chain( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -208,30 +227,18 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - - # Set up a VarInfo with the right accumulators - varinfo = DynamicPPL.setaccs!!( - DynamicPPL.VarInfo(), - ( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(false), - ), + accs = ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), ) - _, varinfo = DynamicPPL.init!!(model, varinfo) - varinfo = DynamicPPL.typed_varinfo(varinfo) - - params_and_stats = AbstractMCMC.to_samples( - DynamicPPL.ParamsWithStats, parameter_only_chain + predictions = map( + DynamicPPL.ParamsWithStats ∘ last, + reevaluate_with_chain( + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() + ), ) - predictions = map(params_and_stats) do ps - _, varinfo = DynamicPPL.init!!( - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) - ) - DynamicPPL.ParamsWithStats(varinfo) - end chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) - parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -311,18 +318,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - first( - DynamicPPL.init!!( - model, - varinfo, - DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), - ), - ) - end + return map(first, reevaluate_with_chain(model, chain, (), nothing)) end """ @@ -415,24 +411,13 @@ function DynamicPPL.pointwise_logdensities( ::Type{Tout}=MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) where {whichlogprob,Tout} - vi = DynamicPPL.VarInfo(model) acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) - vi = DynamicPPL.setaccs!!(vi, (acc,)) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - pointwise_logps = map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) - # Re-evaluate the model - _, vi = DynamicPPL.init!!( - model, - vi, - 
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), - ) - DynamicPPL.getacc(vi, Val(accname)).logps - end - + pointwise_logps = + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) + DynamicPPL.getacc(vi, Val(accname)).logps + end # pointwise_logps is a matrix of OrderedDicts all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps @@ -519,15 +504,15 @@ julia> logjoint(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logjoint(model, argvals_dict) - end + return map( + DynamicPPL.getlogjoint ∘ last, + reevaluate_with_chain( + model, + chain, + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), + nothing, + ), + ) end """ @@ -559,15 +544,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.loglikelihood(model, argvals_dict) - end + return map( + DynamicPPL.getloglikelihood ∘ last, + reevaluate_with_chain( + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing + ), + ) end """ @@ -600,15 +582,10 @@ julia> logprior(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logprior(model, argvals_dict) - end + return map( + DynamicPPL.getlogprior ∘ last, + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), + ) end end diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..d47fb901a 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -1,29 +1,3 @@ -""" - supports_varname_indexing(chain::AbstractChains) - -Return `true` if `chain` supports indexing using `VarName` in place of the -variable name index. -""" -supports_varname_indexing(::AbstractChains) = false - -""" - getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) - -Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function getindex_varname end - -""" - varnames(chains::AbstractChains) - -Return an iterator over the varnames present in `chains`. 
- -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function varnames end - """ ParamsWithStats diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..efc6f1087 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..5b95c8376 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,11 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + DynamicPPL, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -81,6 +86,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +114,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,19 +126,20 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext + return _get_strategy_eltype(leaf_ctx.strategy) else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). 
error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) end end +_get_strategy_eltype(s::InitFromParams) = DynamicPPL.infer_nested_eltype(typeof(s.params)) +# No need to enforce any particular eltype here, since new parameters are sampled +_get_strategy_eltype(::InitFromPrior) = Any +_get_strategy_eltype(::InitFromUniform) = Any +# Default fallback +_get_strategy_eltype(::DynamicPPL.AbstractInitStrategy) = Any """ RangeAndLinked diff --git a/src/model.jl b/src/model.jl index 6ca06aea6..718c56372 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1148,7 +1148,7 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - vi = DynamicPPL.setaccs!!(VarInfo(), ()) + vi = DynamicPPL.Experimental.OnlyAccsVarInfo(AccumulatorTuple()) # Note: we can't use `fix(model, parameters)` because # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 # Use `nothing` as the fallback to ensure that any missing parameters cause an error
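
Below is a minimal sketch of the `returned(model, parameters)` path that the last hunk touches. The toy model and expected value are illustrative only; the `Dict`-based call follows the docstring example visible in the hunk header above.

```julia
using DynamicPPL, Distributions

# Toy model, purely for illustration.
@model function demo()
    m ~ Normal()
    return 2 * m
end

model = demo()

# With `nothing` as the fallback init strategy, any parameter missing from
# the Dict raises an error instead of being silently sampled anew.
returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))  # expected: 4.0
```

The switch to `OnlyAccsVarInfo(AccumulatorTuple())` keeps only accumulators (here, none), so the re-evaluation does not build up a full `VarInfo` just to obtain the model's return value.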