Skip to content

Commit 7cddac7

Browse files
committed
Fast Log Density Function
1 parent 6532d96 commit 7cddac7

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ include("simple_varinfo.jl")
191191
include("compiler.jl")
192192
include("pointwise_logdensities.jl")
193193
include("logdensityfunction.jl")
194+
include("fastldf.jl")
194195
include("model_utils.jl")
195196
include("extract_priors.jl")
196197
include("values_as_in_model.jl")

src/fastldf.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo
2+
accs::Accs
3+
end
4+
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
5+
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
6+
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
7+
8+
struct RangeAndLinked
9+
# indices that the variable corresponds to in the vectorised parameter
10+
range::UnitRange{Int}
11+
# whether it's linked
12+
is_linked::Bool
13+
end
14+
15+
struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext
16+
varname_ranges::Dict{VarName,RangeAndLinked}
17+
params::T
18+
end
19+
DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf()
20+
21+
function tilde_assume!!(
22+
ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo
23+
)
24+
# Don't need to read the data from the varinfo at all since it's
25+
# all inside the context.
26+
range_and_linked = ctx.varname_ranges[vn]
27+
y = @view ctx.params[range_and_linked.range]
28+
is_linked = range_and_linked.is_linked
29+
f = if is_linked
30+
from_linked_vec_transform(right)
31+
else
32+
from_vec_transform(right)
33+
end
34+
x, inv_logjac = with_logabsdet_jacobian(f, y)
35+
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)
36+
return x, vi
37+
end
38+
39+
function tilde_observe!!(
40+
::FastLDFContext,
41+
right::Distribution,
42+
left,
43+
vn::Union{VarName,Nothing},
44+
vi::OnlyAccsVarInfo,
45+
)
46+
# This is the same as for DefaultContext
47+
vi = accumulate_observe!!(vi, right, left, vn)
48+
return left, vi
49+
end
50+
51+
struct FastLDF{M<:Model,F<:Function}
52+
_model::M
53+
_getlogdensity::F
54+
_varname_ranges::Dict{VarName,RangeAndLinked}
55+
56+
function FastLDF(
57+
model::Model,
58+
getlogdensity::Function,
59+
# This only works with typed Metadata-varinfo.
60+
# Obviously, this can be generalised later.
61+
varinfo::VarInfo{<:NamedTuple{syms}},
62+
) where {syms}
63+
# Figure out which variable corresponds to which index, and
64+
# which variables are linked.
65+
all_ranges = Dict{VarName,RangeAndLinked}()
66+
offset = 1
67+
for sym in syms
68+
md = varinfo.metadata[sym]
69+
for (vn, idx) in md.idcs
70+
len = length(md.ranges[idx])
71+
is_linked = md.is_transformed[idx]
72+
range = offset:(offset + len - 1)
73+
all_ranges[vn] = RangeAndLinked(range, is_linked)
74+
offset += len
75+
end
76+
end
77+
return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges)
78+
end
79+
end
80+
81+
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
82+
ctx = FastLDFContext(fldf._varname_ranges, params)
83+
model = DynamicPPL.setleafcontext(fldf._model, ctx)
84+
# This can obviously also be optimised for the case where not
85+
# all accumulators are needed.
86+
accs = AccumulatorTuple((
87+
LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator()
88+
))
89+
_, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
90+
return fldf._getlogdensity(vi)
91+
end

0 commit comments

Comments
 (0)