6868
6969Return a default varinfo object for the given `model` and `sampler`.
7070
71+ The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
72+
7173# Arguments
7274- `rng::Random.AbstractRNG`: Random number generator.
7375- `model::Model`: Model for which we want to create a varinfo object.
@@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`.
7678# Returns
7779- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
7880"""
79- function default_varinfo (rng:: Random.AbstractRNG , model:: Model , sampler:: AbstractSampler )
80- init_sampler = initialsampler (sampler)
81- return typed_varinfo (rng, model, init_sampler)
81+ function default_varinfo (:: Random.AbstractRNG , :: Model , :: AbstractSampler )
82+ # Note that variable values are unconditionally initialized later, so no
83+ # point putting them in now.
84+ return typed_varinfo (VarInfo ())
8285end
8386
8487function AbstractMCMC. sample (
@@ -96,24 +99,32 @@ function AbstractMCMC.sample(
9699 )
97100end
98101
99- # initial step: general interface for resuming and
102+ """
103+ init_strategy(sampler)
104+
105+ Define the initialisation strategy used for generating initial values when
106+ sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
107+ """
108+ init_strategy (:: Sampler ) = PriorInit ()
109+
100110function AbstractMCMC. step (
101- rng:: Random.AbstractRNG , model:: Model , spl:: Sampler ; initial_params= nothing , kwargs...
111+ rng:: Random.AbstractRNG ,
112+ model:: Model ,
113+ spl:: Sampler ;
114+ initial_params:: AbstractInitStrategy = init_strategy (spl),
115+ kwargs... ,
102116)
103- # Sample initial values.
117+ # Generate the default varinfo (usually this just makes an empty VarInfo
118+ # with NamedTuple of Metadata).
104119 vi = default_varinfo (rng, model, spl)
105120
106- # Update the parameters if provided.
107- if initial_params != = nothing
108- vi = initialize_parameters!! (vi, initial_params, model)
109-
110- # Update joint log probability.
111- # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
112- # and https://github.com/TuringLang/Turing.jl/issues/1563
113- # to avoid that existing variables are resampled
114- vi = last (evaluate!! (model, vi))
115- end
121+ # Fill it with initial parameters. Note that, if `ParamsInit` is used, the
122+ # parameters provided must be in unlinked space (when inserted into the
123+ # varinfo, they will be adjusted to match the linking status of the
124+ # varinfo).
125+ _, vi = init!! (rng, model, vi, initial_params)
116126
127+ # Call the actual function that does the first step.
117128 return initialstep (rng, model, spl, vi; initial_params, kwargs... )
118129end
119130
@@ -131,110 +142,7 @@ loadstate(data) = data
131142
132143Default type of the chain of posterior samples from `sampler`.
133144"""
134- default_chain_type (sampler:: Sampler ) = Any
135-
136- """
137- initialsampler(sampler::Sampler)
138-
139- Return the sampler that is used for generating the initial parameters when sampling with
140- `sampler`.
141-
142- By default, it returns an instance of [`SampleFromPrior`](@ref).
143- """
144- initialsampler (spl:: Sampler ) = SampleFromPrior ()
145-
146- """
147- set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
148- set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
149-
150- Take the values inside `initial_params`, replace the corresponding values in
151- the given VarInfo object, and return a new VarInfo object with the updated values.
152-
153- This differs from `DynamicPPL.unflatten` in two ways:
154-
155- 1. It works with `NamedTuple` arguments.
156- 2. For the `AbstractVector` method, if any of the elements are missing, it will not
157- overwrite the original value in the VarInfo (it will just use the original
158- value instead).
159- """
160- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: AbstractVector )
161- throw (
162- ArgumentError (
163- " `initial_params` must be a vector of type `Union{Real,Missing}`. " *
164- " If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first." ,
165- ),
166- )
167- end
168-
169- function set_initial_values (
170- varinfo:: AbstractVarInfo , initial_params:: AbstractVector{<:Union{Real,Missing}}
171- )
172- flattened_param_vals = varinfo[:]
173- length (flattened_param_vals) == length (initial_params) || throw (
174- DimensionMismatch (
175- " Provided initial value size ($(length (initial_params)) ) doesn't match " *
176- " the model size ($(length (flattened_param_vals)) )." ,
177- ),
178- )
179-
180- # Update values that are provided.
181- for i in eachindex (initial_params)
182- x = initial_params[i]
183- if x != = missing
184- flattened_param_vals[i] = x
185- end
186- end
187-
188- # Update in `varinfo`.
189- new_varinfo = unflatten (varinfo, flattened_param_vals)
190- return new_varinfo
191- end
192-
193- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: NamedTuple )
194- varinfo = deepcopy (varinfo)
195- vars_in_varinfo = keys (varinfo)
196- for v in keys (initial_params)
197- vn = VarName {v} ()
198- if ! (vn in vars_in_varinfo)
199- for vv in vars_in_varinfo
200- if subsumes (vn, vv)
201- throw (
202- ArgumentError (
203- " The current model contains sub-variables of $v , such as ($vv ). " *
204- " Using NamedTuple for initial_params is not supported in such a case. " *
205- " Please use AbstractVector for initial_params instead of NamedTuple." ,
206- ),
207- )
208- end
209- end
210- throw (ArgumentError (" Variable $v not found in the model." ))
211- end
212- end
213- initial_params = NamedTuple (k => v for (k, v) in pairs (initial_params) if v != = missing )
214- return update_values!! (
215- varinfo, initial_params, map (k -> VarName {k} (), keys (initial_params))
216- )
217- end
218-
219- function initialize_parameters!! (vi:: AbstractVarInfo , initial_params, model:: Model )
220- @debug " Using passed-in initial variable values" initial_params
221-
222- # `link` the varinfo if needed.
223- linked = islinked (vi)
224- if linked
225- vi = invlink!! (vi, model)
226- end
227-
228- # Set the values in `vi`.
229- vi = set_initial_values (vi, initial_params)
230-
231- # `invlink` if needed.
232- if linked
233- vi = link!! (vi, model)
234- end
235-
236- return vi
237- end
145+ default_chain_type (:: Sampler ) = Any
238146
239147"""
240148 initialstep(rng, model, sampler, varinfo; kwargs...)
0 commit comments