Skip to content

Commit 81282de

Browse files
committed
Simplify OptimLogDensity construction
1 parent 46014a9 commit 81282de

File tree

1 file changed

+18
-61
lines changed

1 file changed

+18
-61
lines changed

src/optimisation/Optimisation.jl

Lines changed: 18 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -43,79 +43,35 @@ Concrete type for maximum a posteriori estimation.
4343
struct MAP <: ModeEstimator end
4444

4545
"""
46-
OptimLogDensity{
47-
M<:DynamicPPL.Model,
48-
F<:Function,
49-
V<:DynamicPPL.AbstractVarInfo,
50-
AD<:ADTypes.AbstractADType,
51-
}
52-
53-
A struct that wraps a single LogDensityFunction. Can be invoked either using
54-
55-
```julia
56-
OptimLogDensity(model, varinfo; adtype=adtype)
57-
```
58-
59-
or
46+
OptimLogDensity{L<:DynamicPPL.LogDensityFunction}
6047
61-
```julia
62-
OptimLogDensity(model; adtype=adtype)
63-
```
48+
A struct that represents a log-density function, which can be used with Optimization.jl.
49+
This is a thin wrapper around `DynamicPPL.LogDensityFunction`: the main difference is that
50+
the log-density is negated (because Optimization.jl performs minimisation, and we usually
51+
want to maximise the log-density).
6452
65-
If not specified, `adtype` defaults to `AutoForwardDiff()`.
53+
An `OptimLogDensity` does not, in itself, obey the LogDensityProblems.jl interface. Thus, if
54+
you want to calculate the log density of its contents at the point `z`, you should manually
55+
call `LogDensityProblems.logdensity(f.ldf, z)`, instead of `LogDensityProblems.logdensity(f,
56+
z)`.
6657
67-
An OptimLogDensity does not, in itself, obey the LogDensityProblems interface.
68-
Thus, if you want to calculate the log density of its contents at the point
69-
`z`, you should manually call
70-
71-
```julia
72-
LogDensityProblems.logdensity(f.ldf, z)
73-
```
74-
75-
However, it is a callable object which returns the *negative* log density of
76-
the underlying LogDensityFunction at the point `z`. This is done to satisfy
77-
the Optim.jl interface.
78-
79-
```julia
80-
optim_ld = OptimLogDensity(model, varinfo)
81-
optim_ld(z) # returns -logp
82-
```
58+
However, because Optimization.jl requires the objective function to be callable, you can
59+
also call `f(z)` directly to get the negative log density at `z`.
8360
"""
8461
struct OptimLogDensity{L<:DynamicPPL.LogDensityFunction}
8562
ldf::L
86-
87-
function OptimLogDensity(
88-
model::DynamicPPL.Model,
89-
getlogdensity::Function,
90-
vi::DynamicPPL.AbstractVarInfo;
91-
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
92-
)
93-
ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype)
94-
return new{typeof(ldf)}(ldf)
95-
end
96-
function OptimLogDensity(
97-
model::DynamicPPL.Model,
98-
getlogdensity::Function;
99-
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
100-
)
101-
# No varinfo
102-
return OptimLogDensity(
103-
model,
104-
getlogdensity,
105-
DynamicPPL.ldf_default_varinfo(model, getlogdensity);
106-
adtype=adtype,
107-
)
108-
end
10963
end
11064

11165
"""
11266
(f::OptimLogDensity)(z)
11367
(f::OptimLogDensity)(z, _)
11468
115-
Evaluate the negative log joint or log likelihood at the array `z`. Which one is evaluated
116-
depends on the context of `f`.
69+
Evaluate the negative log probability density at the array `z`. Which kind of probability
70+
density is evaluated depends on the `getlogdensity` function used to construct the
71+
underlying `LogDensityFunction` (e.g., `DynamicPPL.getlogjoint` for MAP estimation, or
72+
`DynamicPPL.getloglikelihood` for MLE).
11773
118-
Any second argument is ignored. The two-argument method only exists to match interface the
74+
Any second argument is ignored. The two-argument method only exists to match the interface
11975
required by Optimization.jl.
12076
"""
12177
(f::OptimLogDensity)(z::AbstractVector) = -LogDensityProblems.logdensity(f.ldf, z)
@@ -515,7 +471,8 @@ function estimate_mode(
515471
vi = DynamicPPL.link(vi, model)
516472
end
517473

518-
log_density = OptimLogDensity(model, getlogdensity, vi)
474+
ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi; adtype=adtype)
475+
log_density = OptimLogDensity(ldf)
519476

520477
prob = Optimization.OptimizationProblem(log_density, adtype, constraints)
521478
solution = Optimization.solve(prob, solver; kwargs...)

0 commit comments

Comments
 (0)