|
| 1 | +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo |
| 2 | + accs::Accs |
| 3 | +end |
| 4 | +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs |
| 5 | +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi |
| 6 | +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) |
| 7 | + |
| 8 | +struct RangeAndLinked |
| 9 | + # indices that the variable corresponds to in the vectorised parameter |
| 10 | + range::UnitRange{Int} |
| 11 | + # whether it's linked |
| 12 | + is_linked::Bool |
| 13 | +end |
| 14 | + |
| 15 | +struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext |
| 16 | + varname_ranges::Dict{VarName,RangeAndLinked} |
| 17 | + params::T |
| 18 | +end |
| 19 | +DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() |
| 20 | + |
| 21 | +function tilde_assume!!( |
| 22 | + ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo |
| 23 | +) |
| 24 | + # Don't need to read the data from the varinfo at all since it's |
| 25 | + # all inside the context. |
| 26 | + range_and_linked = ctx.varname_ranges[vn] |
| 27 | + y = @view ctx.params[range_and_linked.range] |
| 28 | + is_linked = range_and_linked.is_linked |
| 29 | + f = if is_linked |
| 30 | + from_linked_vec_transform(right) |
| 31 | + else |
| 32 | + from_vec_transform(right) |
| 33 | + end |
| 34 | + x, inv_logjac = with_logabsdet_jacobian(f, y) |
| 35 | + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) |
| 36 | + return x, vi |
| 37 | +end |
| 38 | + |
| 39 | +function tilde_observe!!( |
| 40 | + ::FastLDFContext, |
| 41 | + right::Distribution, |
| 42 | + left, |
| 43 | + vn::Union{VarName,Nothing}, |
| 44 | + vi::OnlyAccsVarInfo, |
| 45 | +) |
| 46 | + # This is the same as for DefaultContext |
| 47 | + vi = accumulate_observe!!(vi, right, left, vn) |
| 48 | + return left, vi |
| 49 | +end |
| 50 | + |
| 51 | +struct FastLDF{M<:Model,F<:Function} |
| 52 | + _model::M |
| 53 | + _getlogdensity::F |
| 54 | + _varname_ranges::Dict{VarName,RangeAndLinked} |
| 55 | + |
| 56 | + function FastLDF( |
| 57 | + model::Model, |
| 58 | + getlogdensity::Function, |
| 59 | + # This only works with typed Metadata-varinfo. |
| 60 | + # Obviously, this can be generalised later. |
| 61 | + varinfo::VarInfo{<:NamedTuple{syms}}, |
| 62 | + ) where {syms} |
| 63 | + # Figure out which variable corresponds to which index, and |
| 64 | + # which variables are linked. |
| 65 | + all_ranges = Dict{VarName,RangeAndLinked}() |
| 66 | + offset = 1 |
| 67 | + for sym in syms |
| 68 | + md = varinfo.metadata[sym] |
| 69 | + for (vn, idx) in md.idcs |
| 70 | + len = length(md.ranges[idx]) |
| 71 | + is_linked = md.is_transformed[idx] |
| 72 | + range = offset:(offset + len - 1) |
| 73 | + all_ranges[vn] = RangeAndLinked(range, is_linked) |
| 74 | + offset += len |
| 75 | + end |
| 76 | + end |
| 77 | + return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) |
| 78 | + end |
| 79 | +end |
| 80 | + |
| 81 | +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) |
| 82 | + ctx = FastLDFContext(fldf._varname_ranges, params) |
| 83 | + model = DynamicPPL.setleafcontext(fldf._model, ctx) |
| 84 | + # This can obviously also be optimised for the case where not |
| 85 | + # all accumulators are needed. |
| 86 | + accs = AccumulatorTuple(( |
| 87 | + LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() |
| 88 | + )) |
| 89 | + _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) |
| 90 | + return fldf._getlogdensity(vi) |
| 91 | +end |
0 commit comments