Skip to content

Commit 5ed4295

Browse files
committed
Make it work with AD
1 parent 7cddac7 commit 5ed4295

File tree

1 file changed

+49
-7
lines changed

1 file changed

+49
-7
lines changed

src/fastldf.jl

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,25 @@ function tilde_observe!!(
4848
return left, vi
4949
end
5050

51-
struct FastLDF{M<:Model,F<:Function}
51+
struct FastLDF{
52+
M<:Model,
53+
F<:Function,
54+
AD<:Union{ADTypes.AbstractADType,Nothing},
55+
ADP<:Union{Nothing,DI.GradientPrep},
56+
}
5257
_model::M
5358
_getlogdensity::F
5459
_varname_ranges::Dict{VarName,RangeAndLinked}
60+
_adtype::AD
61+
_adprep::ADP
5562

5663
function FastLDF(
5764
model::Model,
5865
getlogdensity::Function,
5966
# This only works with typed Metadata-varinfo.
6067
# Obviously, this can be generalised later.
61-
varinfo::VarInfo{<:NamedTuple{syms}},
68+
varinfo::VarInfo{<:NamedTuple{syms}};
69+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
6270
) where {syms}
6371
# Figure out which variable corresponds to which index, and
6472
# which variables are linked.
@@ -74,18 +82,52 @@ struct FastLDF{M<:Model,F<:Function}
7482
offset += len
7583
end
7684
end
77-
return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges)
85+
# Do AD prep if needed
86+
prep = if adtype === nothing
87+
nothing
88+
else
89+
# Make backend-specific tweaks to the adtype
90+
adtype = tweak_adtype(adtype, model, varinfo)
91+
x = [val for val in varinfo[:]]
92+
DI.prepare_gradient(
93+
FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x
94+
)
95+
end
96+
97+
return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}(
98+
model, getlogdensity, all_ranges, adtype, prep
99+
)
78100
end
79101
end
80102

81-
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
82-
ctx = FastLDFContext(fldf._varname_ranges, params)
83-
model = DynamicPPL.setleafcontext(fldf._model, ctx)
103+
struct FastLogDensityAt{M<:Model,F<:Function}
104+
_model::M
105+
_getlogdensity::F
106+
_varname_ranges::Dict{VarName,RangeAndLinked}
107+
end
108+
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
109+
ctx = FastLDFContext(f._varname_ranges, params)
110+
model = DynamicPPL.setleafcontext(f._model, ctx)
84111
# This can obviously also be optimised for the case where not
85112
# all accumulators are needed.
86113
accs = AccumulatorTuple((
87114
LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator()
88115
))
89116
_, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
90-
return fldf._getlogdensity(vi)
117+
return f._getlogdensity(vi)
118+
end
119+
120+
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
121+
return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params)
122+
end
123+
124+
function LogDensityProblems.logdensity_and_gradient(
125+
fldf::FastLDF, params::AbstractVector{<:Real}
126+
)
127+
return DI.value_and_gradient(
128+
FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges),
129+
fldf._adprep,
130+
fldf._adtype,
131+
params,
132+
)
91133
end

0 commit comments

Comments
 (0)