
Commit ea9bb54

Merge remote-tracking branch 'origin/main' into breaking
2 parents 7835d7f + 4153a83 commit ea9bb54

File tree

- HISTORY.md
- ext/TuringOptimExt.jl
- src/optimisation/Optimisation.jl
- test/optimisation/Optimisation.jl

4 files changed: +83 −7 lines changed

HISTORY.md

Lines changed: 18 additions & 0 deletions
````diff
@@ -18,6 +18,24 @@ As long as the above functions are defined correctly, Turing will be able to use
 
 The `Turing.Inference.isgibbscomponent(::MySampler)` interface function still exists, but in this version the default has been changed to `true`, so you should not need to overload this.
 
+# 0.41.1
+
+The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
+This makes it easier to use the parameters in downstream code, e.g. when specifying initial parameters for MCMC sampling.
+For example:
+
+```julia
+@model function f()
+    # ...
+end
+model = f()
+opt_result = maximum_a_posteriori(model)
+
+sample(model, NUTS(), 1000; initial_params=InitFromParams(opt_result))
+```
+
+If you need to access the dictionary of parameters, it is stored in `opt_result.params`, but note that this field may change in future breaking releases, as Turing's optimisation interface is slated for overhaul in the near future.
+
 # 0.41.0
 
 ## DynamicPPL 0.38
````
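As a companion to the changelog example above, here is a minimal sketch of reading the parameter dictionary directly (the one-parameter model `g` is hypothetical, not taken from this commit):

```julia
using Turing

@model function g()
    x ~ Normal(0, 1)  # hypothetical single-parameter model, for illustration
end

opt_result = maximum_a_posteriori(g())

# `params` maps VarNames to the optimised values; keys are built with @varname.
# As the changelog warns, this field may change in a future breaking release.
x_hat = opt_result.params[@varname(x)]
```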

ext/TuringOptimExt.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -192,7 +192,7 @@ function _optimize(
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     vmat = NamedArrays.NamedArray(vals, varnames)
-    return Optimisation.ModeResult(vmat, M, -M.minimum, logdensity_optimum)
+    return Optimisation.ModeResult(vmat, M, -M.minimum, logdensity_optimum, vals_dict)
 end
 
 end # module
```
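This makes a fifth positional argument to `ModeResult` mandatory wherever it is constructed. The `vals_dict` passed here is built earlier in `_optimize`, outside this hunk; judging from the type constraint added in `src/optimisation/Optimisation.jl` below, it needs a shape roughly like this sketch (the parameter names are placeholders):

```julia
using AbstractPPL: VarName, @varname

# Sketch only: any AbstractDict keyed by VarName satisfies the new
# `params` field constraint; these entries are made up.
vals_dict = Dict{VarName,Float64}(@varname(μ) => 0.0, @varname(σ) => 1.0)
```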

src/optimisation/Optimisation.jl

Lines changed: 30 additions & 5 deletions
```diff
@@ -4,6 +4,7 @@ using ..Turing
 using NamedArrays: NamedArrays
 using AbstractPPL: AbstractPPL
 using DynamicPPL: DynamicPPL
+using DocStringExtensions: TYPEDFIELDS
 using LogDensityProblems: LogDensityProblems
 using Optimization: Optimization
 using OptimizationOptimJL: OptimizationOptimJL
@@ -154,13 +155,22 @@ end
         V<:NamedArrays.NamedArray,
         M<:NamedArrays.NamedArray,
         O<:Optim.MultivariateOptimizationResults,
-        S<:NamedArrays.NamedArray
+        S<:NamedArrays.NamedArray,
+        P<:AbstractDict{<:VarName,<:Any}
     }
 
 A wrapper struct to store various results from a MAP or MLE estimation.
+
+## Fields
+
+$(TYPEDFIELDS)
 """
-struct ModeResult{V<:NamedArrays.NamedArray,O<:Any,M<:OptimLogDensity} <:
-       StatsBase.StatisticalModel
+struct ModeResult{
+    V<:NamedArrays.NamedArray,
+    O<:Any,
+    M<:OptimLogDensity,
+    P<:AbstractDict{<:AbstractPPL.VarName,<:Any},
+} <: StatsBase.StatisticalModel
     "A vector with the resulting point estimates."
     values::V
     "The stored optimiser results."
@@ -169,6 +179,8 @@ struct ModeResult{V<:NamedArrays.NamedArray,O<:Any,M<:OptimLogDensity} <:
     lp::Float64
     "The evaluation function used to calculate the output."
     f::M
+    "Dictionary of parameter values"
+    params::P
 end
 
 function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
@@ -182,6 +194,15 @@ function Base.show(io::IO, m::ModeResult)
     return show(io, m.values.array)
 end
 
+"""
+    InitFromParams(m::ModeResult)
+
+Initialize a model from the parameters stored in a `ModeResult`.
+"""
+function DynamicPPL.InitFromParams(m::ModeResult)
+    return DynamicPPL.InitFromParams(m.params)
+end
+
 # Various StatsBase methods for ModeResult
 
 """
@@ -355,9 +376,13 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati
     iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
     vns_vals_iter = mapreduce(collect, vcat, iters)
     syms = map(Symbol ∘ first, vns_vals_iter)
-    vals = map(last, vns_vals_iter)
+    split_vals = map(last, vns_vals_iter)
     return ModeResult(
-        NamedArrays.NamedArray(vals, syms), solution, -solution.objective, log_density
+        NamedArrays.NamedArray(split_vals, syms),
+        solution,
+        -solution.objective,
+        log_density,
+        vals,
     )
 end
```
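The new `DynamicPPL.InitFromParams` method is a one-line delegation to the stored dictionary, so the two `sample` calls below construct the same initialisation strategy. A short sketch with a hypothetical model (not part of the diff):

```julia
using Turing

@model function demo(y)
    m ~ Normal(0, 1)
    y ~ Normal(m, 1)
end

model = demo(1.0)
res = maximum_likelihood(model)

# Equivalent, since InitFromParams(res) simply forwards res.params:
sample(model, NUTS(), 100; initial_params=InitFromParams(res))
sample(model, NUTS(), 100; initial_params=InitFromParams(res.params))
```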

test/optimisation/Optimisation.jl

Lines changed: 34 additions & 1 deletion
```diff
@@ -101,6 +101,13 @@ using Turing
         @test result.optim_result.retcode == Optimization.ReturnCode.Success
     end
     @test isapprox(result.lp, true_logp, atol=0.01)
+    # check that the parameter dict matches the NamedArray
+    # NOTE: This test only works for models where all parameters are identity
+    # varnames AND real-valued. Thankfully, this is true for `gdemo`.
+    @test length(only(result.values.dicts)) == length(result.params)
+    for (k, index) in only(result.values.dicts)
+        @test result.params[AbstractPPL.VarName{k}()] == result.values.array[index]
+    end
 end
 
 @testset "MLE" begin
@@ -546,6 +553,26 @@ using Turing
     end
 end
 
+@testset "using ModeResult to initialise MCMC" begin
+    @model function f(y)
+        μ ~ Normal(0, 1)
+        σ ~ Gamma(2, 1)
+        return y ~ Normal(μ, σ)
+    end
+    model = f(randn(10))
+    mle = maximum_likelihood(model)
+    # TODO(penelopeysm): This relies on the fact that HMC does indeed
+    # use the initial_params passed to it. We should use something
+    # like a StaticSampler (see test/mcmc/Inference) to make this more
+    # robust.
+    chain = sample(
+        model, HMC(0.1, 10), 2; initial_params=InitFromParams(mle), num_warmup=0
+    )
+    # Check that those parameters were indeed used as initial params
+    @test chain[:μ][1] == mle.params[@varname(μ)]
+    @test chain[:σ][1] == mle.params[@varname(σ)]
+end
+
 # Issue: https://discourse.julialang.org/t/turing-mixture-models-with-dirichlet-weightings/112910
 @testset "Optimization with different linked dimensionality" begin
     @model demo_dirichlet() = x ~ Dirichlet(2 * ones(3))
@@ -621,7 +648,13 @@ using Turing
     m = saddle_model()
     optim_ld = Turing.Optimisation.OptimLogDensity(m, DynamicPPL.getloglikelihood)
     vals = Turing.Optimisation.NamedArrays.NamedArray([0.0, 0.0])
-    m = Turing.Optimisation.ModeResult(vals, nothing, 0.0, optim_ld)
+    m = Turing.Optimisation.ModeResult(
+        vals,
+        nothing,
+        0.0,
+        optim_ld,
+        Dict{AbstractPPL.VarName,Float64}(@varname(x) => 0.0, @varname(y) => 0.0),
+    )
     ct = coeftable(m)
     @assert isnan(ct.cols[2][1])
     @assert ct.colnms[end] == "Error notes"
```
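The NOTE in the first hunk deserves a gloss: the lookup `result.params[AbstractPPL.VarName{k}()]` only succeeds when each `NamedArray` key `k` corresponds to a whole, unindexed varname. A small illustrative sketch of where that assumption breaks (not part of the diff):

```julia
using AbstractPPL: VarName, @varname

# Identity varname: a bare symbol, so the Symbol-to-VarName round trip works.
@varname(x) == VarName{:x}()                   # true

# Indexed varname: the NamedArray key would be Symbol("x[1]"), and rebuilding
# it as an identity VarName yields a different varname entirely; hence the
# test's restriction to identity, real-valued parameters.
@varname(x[1]) == VarName{Symbol("x[1]")}()    # false
```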
