Skip to content

Commit 676f11e

Browse files
committed
Simplify OptimLogDensity construction
1 parent 6bceedb commit 676f11e

File tree

1 file changed

+18
-61
lines changed

1 file changed

+18
-61
lines changed

src/optimisation/Optimisation.jl

Lines changed: 18 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -44,79 +44,35 @@ Concrete type for maximum a posteriori estimation.
4444
struct MAP <: ModeEstimator end
4545

4646
"""
47-
OptimLogDensity{
48-
M<:DynamicPPL.Model,
49-
F<:Function,
50-
V<:DynamicPPL.AbstractVarInfo,
51-
AD<:ADTypes.AbstractADType,
52-
}
53-
54-
A struct that wraps a single LogDensityFunction. Can be invoked either using
55-
56-
```julia
57-
OptimLogDensity(model, varinfo; adtype=adtype)
58-
```
59-
60-
or
47+
OptimLogDensity{L<:DynamicPPL.LogDensityFunction}
6148
62-
```julia
63-
OptimLogDensity(model; adtype=adtype)
64-
```
49+
A struct that represents a log-density function, which can be used with Optimization.jl.
50+
This is a thin wrapper around `DynamicPPL.LogDensityFunction`: the main difference is that
51+
the log-density is negated (because Optimization.jl performs minimisation, and we usually
52+
want to maximise the log-density).
6553
66-
If not specified, `adtype` defaults to `AutoForwardDiff()`.
54+
An `OptimLogDensity` does not, in itself, obey the LogDensityProblems.jl interface. Thus, if
55+
you want to calculate the log density of its contents at the point `z`, you should manually
56+
call `LogDensityProblems.logdensity(f.ldf, z)`, instead of `LogDensityProblems.logdensity(f,
57+
z)`.
6758
68-
An OptimLogDensity does not, in itself, obey the LogDensityProblems interface.
69-
Thus, if you want to calculate the log density of its contents at the point
70-
`z`, you should manually call
71-
72-
```julia
73-
LogDensityProblems.logdensity(f.ldf, z)
74-
```
75-
76-
However, it is a callable object which returns the *negative* log density of
77-
the underlying LogDensityFunction at the point `z`. This is done to satisfy
78-
the Optim.jl interface.
79-
80-
```julia
81-
optim_ld = OptimLogDensity(model, varinfo)
82-
optim_ld(z) # returns -logp
83-
```
59+
However, because Optimization.jl requires the objective function to be callable, you can
60+
also call `f(z)` directly to get the negative log density at `z`.
8461
"""
8562
struct OptimLogDensity{L<:DynamicPPL.LogDensityFunction}
8663
ldf::L
87-
88-
function OptimLogDensity(
89-
model::DynamicPPL.Model,
90-
getlogdensity::Function,
91-
vi::DynamicPPL.AbstractVarInfo;
92-
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
93-
)
94-
ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype)
95-
return new{typeof(ldf)}(ldf)
96-
end
97-
function OptimLogDensity(
98-
model::DynamicPPL.Model,
99-
getlogdensity::Function;
100-
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
101-
)
102-
# No varinfo
103-
return OptimLogDensity(
104-
model,
105-
getlogdensity,
106-
DynamicPPL.ldf_default_varinfo(model, getlogdensity);
107-
adtype=adtype,
108-
)
109-
end
11064
end
11165

11266
"""
11367
(f::OptimLogDensity)(z)
11468
(f::OptimLogDensity)(z, _)
11569
116-
Evaluate the negative log joint or log likelihood at the array `z`. Which one is evaluated
117-
depends on the context of `f`.
70+
Evaluate the negative log probability density at the array `z`. Which kind of probability
71+
density is evaluated depends on the `getlogdensity` function used to construct the
72+
underlying `LogDensityFunction` (e.g., `DynamicPPL.getlogjoint` for MAP estimation, or
73+
`DynamicPPL.getloglikelihood` for MLE).
11874
119-
Any second argument is ignored. The two-argument method only exists to match interface the
75+
Any second argument is ignored. The two-argument method only exists to match the interface
12076
required by Optimization.jl.
12177
"""
12278
(f::OptimLogDensity)(z::AbstractVector) = -LogDensityProblems.logdensity(f.ldf, z)
@@ -540,7 +496,8 @@ function estimate_mode(
540496
vi = DynamicPPL.link(vi, model)
541497
end
542498

543-
log_density = OptimLogDensity(model, getlogdensity, vi)
499+
ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype)
500+
log_density = OptimLogDensity(ldf)
544501

545502
prob = Optimization.OptimizationProblem(log_density, adtype, constraints)
546503
solution = Optimization.solve(prob, solver; kwargs...)

0 commit comments

Comments
 (0)