@@ -4,6 +4,7 @@ using ..Turing
44using NamedArrays: NamedArrays
55using AbstractPPL: AbstractPPL
66using DynamicPPL: DynamicPPL
7+ using DocStringExtensions: TYPEDFIELDS
78using LogDensityProblems: LogDensityProblems
89using Optimization: Optimization
910using OptimizationOptimJL: OptimizationOptimJL
@@ -154,13 +155,22 @@ end
154155 V<:NamedArrays.NamedArray,
155156 M<:NamedArrays.NamedArray,
156157 O<:Optim.MultivariateOptimizationResults,
157- S<:NamedArrays.NamedArray
158+ S<:NamedArrays.NamedArray,
159+ P<:AbstractDict{<:VarName,<:Any}
158160 }
159161
160162A wrapper struct to store various results from a MAP or MLE estimation.
163+
164+ ## Fields
165+
166+ $(TYPEDFIELDS)
161167"""
162- struct ModeResult{V<: NamedArrays.NamedArray ,O<: Any ,M<: OptimLogDensity } < :
163- StatsBase. StatisticalModel
168+ struct ModeResult{
169+ V<: NamedArrays.NamedArray ,
170+ O<: Any ,
171+ M<: OptimLogDensity ,
172+ P<: AbstractDict{<:AbstractPPL.VarName,<:Any} ,
173+ } <: StatsBase.StatisticalModel
164174 " A vector with the resulting point estimates."
165175 values:: V
166176 " The stored optimiser results."
@@ -169,6 +179,8 @@ struct ModeResult{V<:NamedArrays.NamedArray,O<:Any,M<:OptimLogDensity} <:
169179 lp:: Float64
170180 " The evaluation function used to calculate the output."
171181 f:: M
182+ " Dictionary of parameter values"
183+ params:: P
172184end
173185
174186function Base. show (io:: IO , :: MIME"text/plain" , m:: ModeResult )
@@ -182,6 +194,15 @@ function Base.show(io::IO, m::ModeResult)
182194 return show (io, m. values. array)
183195end
184196
197+ """
198+ InitFromParams(m::ModeResult)
199+
200+ Initialize a model from the parameters stored in a `ModeResult`.
201+ """
202+ function DynamicPPL. InitFromParams (m:: ModeResult )
203+ return DynamicPPL. InitFromParams (m. params)
204+ end
205+
185206# Various StatsBase methods for ModeResult
186207
187208"""
@@ -355,9 +376,13 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati
355376 iters = map (AbstractPPL. varname_and_value_leaves, keys (vals), values (vals))
356377 vns_vals_iter = mapreduce (collect, vcat, iters)
357378 syms = map (Symbol ∘ first, vns_vals_iter)
358- vals = map (last, vns_vals_iter)
379+ split_vals = map (last, vns_vals_iter)
359380 return ModeResult (
360- NamedArrays. NamedArray (vals, syms), solution, - solution. objective, log_density
381+ NamedArrays. NamedArray (split_vals, syms),
382+ solution,
383+ - solution. objective,
384+ log_density,
385+ vals,
361386 )
362387end
363388
0 commit comments