@@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818"""
1919 LogDensityFunction(
2020 model::Model,
21- varinfo::AbstractVarInfo=VarInfo(model);
21+ getlogdensity::Function=getlogjoint,
22+ varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
2223 adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
2324 )
2425
@@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to:
2829 - and if `adtype` is provided, calculate the gradient of the log density at
2930 that point.
3031
31- At its most basic level, a LogDensityFunction wraps the model together with the
32- type of varinfo to be used. These must be known in order to calculate the log
33- density (using [`DynamicPPL.evaluate!!`](@ref)).
32+ At its most basic level, a LogDensityFunction wraps the model together with a
33+ function that specifies how to extract the log density, and the type of
34+ VarInfo to be used. These must be known in order to calculate the log density
35+ (using [`DynamicPPL.evaluate!!`](@ref)).
3436
3537If the `adtype` keyword argument is provided, then this struct will also store
3638the adtype along with other information for efficient calculation of the
@@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f)
72741
7375
7476julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
75- f = LogDensityFunction(model, SimpleVarInfo(model));
77+ f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));
7678
7779julia> LogDensityProblems.logdensity(f, [0.0])
7880-2.3378770664093453
7981
80- julia> # LogDensityFunction respects the accumulators in VarInfo :
81- f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) );
82+ julia> # One can also evaluate only the log prior, for example:
83+ f_prior = LogDensityFunction(model, getlogprior );
8284
8385julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
8486true
@@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
9395```
9496"""
9597struct LogDensityFunction{
96- M<: Model ,V<: AbstractVarInfo ,AD<: Union{Nothing,ADTypes.AbstractADType}
98+ M<: Model ,F <: Function , V<: AbstractVarInfo ,AD<: Union{Nothing,ADTypes.AbstractADType}
9799} <: AbstractModel
98100 " model used for evaluation"
99101 model:: M
100- " varinfo used for evaluation"
102+ " function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
103+ getlogdensity:: F
104+ " varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
101105 varinfo:: V
102106 " AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
103107 adtype:: AD
@@ -106,7 +110,8 @@ struct LogDensityFunction{
106110
107111 function LogDensityFunction (
108112 model:: Model ,
109- varinfo:: AbstractVarInfo = VarInfo (model);
113+ getlogdensity:: Function = getlogjoint,
114+ varinfo:: AbstractVarInfo = ldf_default_varinfo (model, getlogdensity);
110115 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
111116 )
112117 if adtype === nothing
@@ -120,15 +125,22 @@ struct LogDensityFunction{
120125 # Get a set of dummy params to use for prep
121126 x = map (identity, varinfo[:])
122127 if use_closure (adtype)
123- prep = DI. prepare_gradient (LogDensityAt (model, varinfo), adtype, x)
128+ prep = DI. prepare_gradient (
129+ LogDensityAt (model, getlogdensity, varinfo), adtype, x
130+ )
124131 else
125132 prep = DI. prepare_gradient (
126- logdensity_at, adtype, x, DI. Constant (model), DI. Constant (varinfo)
133+ logdensity_at,
134+ adtype,
135+ x,
136+ DI. Constant (model),
137+ DI. Constant (getlogdensity),
138+ DI. Constant (varinfo),
127139 )
128140 end
129141 end
130- return new {typeof(model),typeof(varinfo),typeof(adtype)} (
131- model, varinfo, adtype, prep
142+ return new {typeof(model),typeof(getlogdensity),typeof( varinfo),typeof(adtype)} (
143+ model, getlogdensity, varinfo, adtype, prep
132144 )
133145 end
134146end
@@ -149,83 +161,112 @@ function LogDensityFunction(
149161 return if adtype === f. adtype
150162 f # Avoid recomputing prep if not needed
151163 else
152- LogDensityFunction (f. model, f. varinfo; adtype= adtype)
164+ LogDensityFunction (f. model, f. getlogdensity, f . varinfo; adtype= adtype)
153165 end
154166end
155167
168+ """
169+ ldf_default_varinfo(model::Model, getlogdensity::Function)
170+
171+ Create the default AbstractVarInfo that should be used for evaluating the log density.
172+
173+ Only the accumulators necessary for `getlogdensity` will be used.
174+ """
175+ function ldf_default_varinfo (:: Model , getlogdensity:: Function )
176+ msg = """
177+ LogDensityFunction does not know what sort of VarInfo should be used when \
178+ `getlogdensity` is $getlogdensity . Please specify a VarInfo explicitly.
179+ """
180+ return error (msg)
181+ end
182+
183+ ldf_default_varinfo (model:: Model , :: typeof (getlogjoint)) = VarInfo (model)
184+
185+ function ldf_default_varinfo (model:: Model , :: typeof (getlogprior))
186+ return setaccs!! (VarInfo (model), (LogPriorAccumulator (),))
187+ end
188+
189+ function ldf_default_varinfo (model:: Model , :: typeof (getloglikelihood))
190+ return setaccs!! (VarInfo (model), (LogLikelihoodAccumulator (),))
191+ end
192+
156193"""
157194 logdensity_at(
158195 x::AbstractVector,
159196 model::Model,
197+ getlogdensity::Function,
160198 varinfo::AbstractVarInfo,
161199 )
162200
163- Evaluate the log density of the given `model` at the given parameter values `x`,
164- using the given `varinfo`. Note that the `varinfo` argument is provided only
165- for its structure, in the sense that the parameters from the vector `x` are
166- inserted into it, and its own parameters are discarded. It does, however,
167- determine whether the log prior, likelihood, or joint is returned, based on
168- which accumulators are set in it.
201+ Evaluate the log density of the given `model` at the given parameter values
202+ `x`, using the given `varinfo`. Note that the `varinfo` argument is provided
203+ only for its structure, in the sense that the parameters from the vector `x`
204+ are inserted into it, and its own parameters are discarded. `getlogdensity` is
205+ the function that extracts the log density from the evaluated varinfo.
169206"""
170- function logdensity_at (x:: AbstractVector , model:: Model , varinfo:: AbstractVarInfo )
207+ function logdensity_at (
208+ x:: AbstractVector , model:: Model , getlogdensity:: Function , varinfo:: AbstractVarInfo
209+ )
171210 varinfo_new = unflatten (varinfo, x)
172211 varinfo_eval = last (evaluate!! (model, varinfo_new))
173- has_prior = hasacc (varinfo_eval, Val (:LogPrior ))
174- has_likelihood = hasacc (varinfo_eval, Val (:LogLikelihood ))
175- if has_prior && has_likelihood
176- return getlogjoint (varinfo_eval)
177- elseif has_prior
178- return getlogprior (varinfo_eval)
179- elseif has_likelihood
180- return getloglikelihood (varinfo_eval)
181- else
182- error (" LogDensityFunction: varinfo tracks neither log prior nor log likelihood" )
183- end
212+ return getlogdensity (varinfo_eval)
184213end
185214
186215"""
187- LogDensityAt{M<:Model,V<:AbstractVarInfo}(
216+ LogDensityAt{M<:Model,F<:Function, V<:AbstractVarInfo}(
188217 model::M,
218+ getlogdensity::F,
189219 varinfo::V
190220 )
191221
192222A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
193- varinfo)`.
223+ getlogdensity, varinfo)`.
194224"""
195- struct LogDensityAt{M<: Model ,V<: AbstractVarInfo }
225+ struct LogDensityAt{M<: Model ,F <: Function , V<: AbstractVarInfo }
196226 model:: M
227+ getlogdensity:: F
197228 varinfo:: V
198229end
199- (ld:: LogDensityAt )(x:: AbstractVector ) = logdensity_at (x, ld. model, ld. varinfo)
230+ function (ld:: LogDensityAt )(x:: AbstractVector )
231+ return logdensity_at (x, ld. model, ld. getlogdensity, ld. varinfo)
232+ end
200233
201234# ## LogDensityProblems interface
202235
203236function LogDensityProblems. capabilities (
204- :: Type{<:LogDensityFunction{M,V,Nothing}}
205- ) where {M,V}
237+ :: Type{<:LogDensityFunction{M,F, V,Nothing}}
238+ ) where {M,F, V}
206239 return LogDensityProblems. LogDensityOrder {0} ()
207240end
208241function LogDensityProblems. capabilities (
209- :: Type{<:LogDensityFunction{M,V,AD}}
210- ) where {M,V,AD<: ADTypes.AbstractADType }
242+ :: Type{<:LogDensityFunction{M,F, V,AD}}
243+ ) where {M,F, V,AD<: ADTypes.AbstractADType }
211244 return LogDensityProblems. LogDensityOrder {1} ()
212245end
213246function LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
214- return logdensity_at (x, f. model, f. varinfo)
247+ return logdensity_at (x, f. model, f. getlogdensity, f . varinfo)
215248end
216249function LogDensityProblems. logdensity_and_gradient (
217- f:: LogDensityFunction{M,V,AD} , x:: AbstractVector
218- ) where {M,V,AD<: ADTypes.AbstractADType }
250+ f:: LogDensityFunction{M,F, V,AD} , x:: AbstractVector
251+ ) where {M,F, V,AD<: ADTypes.AbstractADType }
219252 f. prep === nothing &&
220253 error (" Gradient preparation not available; this should not happen" )
221254 x = map (identity, x) # Concretise type
222255 # Make branching statically inferrable, i.e. type-stable (even if the two
223256 # branches happen to return different types)
224257 return if use_closure (f. adtype)
225- DI. value_and_gradient (LogDensityAt (f. model, f. varinfo), f. prep, f. adtype, x)
258+ DI. value_and_gradient (
259+ LogDensityAt (f. model, f. getlogdensity, f. varinfo), f. prep, f. adtype, x
260+ )
226261 else
227262 DI. value_and_gradient (
228- logdensity_at, f. prep, f. adtype, x, DI. Constant (f. model), DI. Constant (f. varinfo)
263+ logdensity_at,
264+ f. prep,
265+ f. adtype,
266+ x,
267+ DI. Constant (f. model),
268+ DI. Constant (f. getlogdensity),
269+ DI. Constant (f. varinfo),
229270 )
230271 end
231272end
@@ -264,9 +305,9 @@ There are two ways of dealing with this:
264305
2653061. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)
266307
267- 2. Use a constant context. This lets us pass a two-argument function to
268- DifferentiationInterface, as long as we also give it the 'inactive argument'
269- (i.e. the model) wrapped in `DI.Constant`.
308+ 2. Use a constant DI.Context. This lets us pass a two-argument function to DI,
309+ as long as we also give it the 'inactive argument' (i.e. the model) wrapped
310+ in `DI.Constant`.
270311
271312The relative performance of the two approaches, however, depends on the AD
272313backend used. Some benchmarks are provided here:
@@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292333Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
293334"""
294335function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
295- return LogDensityFunction (model, f. varinfo; adtype= f. adtype)
336+ return LogDensityFunction (model, f. getlogdensity, f . varinfo; adtype= f. adtype)
296337end
297338
298339"""
0 commit comments