@@ -48,17 +48,25 @@ function tilde_observe!!(
4848 return left, vi
4949end
5050
51- struct FastLDF{M<: Model ,F<: Function }
51+ struct FastLDF{
52+ M<: Model ,
53+ F<: Function ,
54+ AD<: Union{ADTypes.AbstractADType,Nothing} ,
55+ ADP<: Union{Nothing,DI.GradientPrep} ,
56+ }
5257 _model:: M
5358 _getlogdensity:: F
5459 _varname_ranges:: Dict{VarName,RangeAndLinked}
60+ _adtype:: AD
61+ _adprep:: ADP
5562
5663 function FastLDF (
5764 model:: Model ,
5865 getlogdensity:: Function ,
5966 # This only works with typed Metadata-varinfo.
6067 # Obviously, this can be generalised later.
61- varinfo:: VarInfo{<:NamedTuple{syms}} ,
68+ varinfo:: VarInfo{<:NamedTuple{syms}} ;
69+ adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
6270 ) where {syms}
6371 # Figure out which variable corresponds to which index, and
6472 # which variables are linked.
@@ -74,18 +82,52 @@ struct FastLDF{M<:Model,F<:Function}
7482 offset += len
7583 end
7684 end
77- return new {typeof(model),typeof(getlogdensity)} (model, getlogdensity, all_ranges)
85+ # Do AD prep if needed
86+ prep = if adtype === nothing
87+ nothing
88+ else
89+ # Make backend-specific tweaks to the adtype
90+ adtype = tweak_adtype (adtype, model, varinfo)
91+ x = [val for val in varinfo[:]]
92+ DI. prepare_gradient (
93+ FastLogDensityAt (model, getlogdensity, all_ranges), adtype, x
94+ )
95+ end
96+
97+ return new {typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)} (
98+ model, getlogdensity, all_ranges, adtype, prep
99+ )
78100 end
79101end
80102
81- function LogDensityProblems. logdensity (fldf:: FastLDF , params:: AbstractVector{<:Real} )
82- ctx = FastLDFContext (fldf. _varname_ranges, params)
83- model = DynamicPPL. setleafcontext (fldf. _model, ctx)
103+ struct FastLogDensityAt{M<: Model ,F<: Function }
104+ _model:: M
105+ _getlogdensity:: F
106+ _varname_ranges:: Dict{VarName,RangeAndLinked}
107+ end
108+ function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
109+ ctx = FastLDFContext (f. _varname_ranges, params)
110+ model = DynamicPPL. setleafcontext (f. _model, ctx)
84111 # This can obviously also be optimised for the case where not
85112 # all accumulators are needed.
86113 accs = AccumulatorTuple ((
87114 LogPriorAccumulator (), LogLikelihoodAccumulator (), LogJacobianAccumulator ()
88115 ))
89116 _, vi = DynamicPPL. _evaluate!! (model, OnlyAccsVarInfo (accs))
90- return fldf. _getlogdensity (vi)
117+ return f. _getlogdensity (vi)
118+ end
119+
120+ function LogDensityProblems. logdensity (fldf:: FastLDF , params:: AbstractVector{<:Real} )
121+ return FastLogDensityAt (fldf. _model, fldf. _getlogdensity, fldf. _varname_ranges)(params)
122+ end
123+
124+ function LogDensityProblems. logdensity_and_gradient (
125+ fldf:: FastLDF , params:: AbstractVector{<:Real}
126+ )
127+ return DI. value_and_gradient (
128+ FastLogDensityAt (fldf. _model, fldf. _getlogdensity, fldf. _varname_ranges),
129+ fldf. _adprep,
130+ fldf. _adtype,
131+ params,
132+ )
91133end
0 commit comments