1- # Uniform random numbers with range 4 for robust initializations
1+ # Uniform random numbers on [-2, 2) (an interval of width 4) for robust initializations
22# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
# Draw a single value uniformly from [-2, 2) using `rng`.
function randrealuni(rng::Random.AbstractRNG)
    return 4 * rand(rng) - 2
end

# Same rescaling applied element-wise; `args...` is forwarded unchanged to
# `rand` (e.g. the dimensions of the array to draw).
function randrealuni(rng::Random.AbstractRNG, args...)
    draws = rand(rng, args...)
    return 4 .* draws .- 2
end
55
6- istransformable (dist) = link_transform (dist) != = identity
7-
8- # ################################
9- # Single-sample initialisations #
10- # ################################
11- inittrans (rng, dist:: UnivariateDistribution ) = Bijectors. invlink (dist, randrealuni (rng))
12- function inittrans (rng, dist:: MultivariateDistribution )
13- # Get the length of the unconstrained vector
14- b = link_transform (dist)
15- d = Bijectors. output_length (b, length (dist))
16- return Bijectors. invlink (dist, randrealuni (rng, d))
17- end
18- function inittrans (rng, dist:: MatrixDistribution )
19- # Get the size of the unconstrained vector
20- b = link_transform (dist)
21- sz = Bijectors. output_size (b, size (dist))
22- return Bijectors. invlink (dist, randrealuni (rng, sz... ))
23- end
24- function inittrans (rng, dist:: Distribution{CholeskyVariate} )
25- # Get the size of the unconstrained vector
26- b = link_transform (dist)
27- sz = Bijectors. output_size (b, size (dist))
28- return Bijectors. invlink (dist, randrealuni (rng, sz... ))
29- end
30- # ###############################
31- # Multi-sample initialisations #
32- # ###############################
33- function inittrans (rng, dist:: UnivariateDistribution , n:: Int )
34- return Bijectors. invlink (dist, randrealuni (rng, n))
35- end
36- function inittrans (rng, dist:: MultivariateDistribution , n:: Int )
37- return Bijectors. invlink (dist, randrealuni (rng, size (dist)[1 ], n))
38- end
39- function inittrans (rng, dist:: MatrixDistribution , n:: Int )
40- return Bijectors. invlink (dist, [randrealuni (rng, size (dist)... ) for _ in 1 : n])
41- end
42-
436"""
447 AbstractInitStrategy
458
@@ -49,15 +12,29 @@ the random variables in a model (e.g., when creating a new VarInfo).
4912abstract type AbstractInitStrategy end
5013
5114"""
52- Prior()
15+ init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
16+
17+ Generate a new value for a random variable with the given distribution.
18+
19+ !!! warning "Values must be unlinked"
20+ The values returned by `init` are always in the untransformed space, i.e.,
21+ they must be within the support of the original distribution. That means that,
22+ for example, `init(rng, vn, dist, u::UniformInit)` will in general return values that
23+ are outside the range [u.lower, u.upper].
24+ """
25+ function init end
5326
54- Obtain new values by sampling from the prior.
5527"""
56- struct Prior <: AbstractInitStrategy end
28+ PriorInit()
5729
30+ Obtain new values by sampling from the prior distribution.
5831"""
59- Uniform()
60- Uniform(lower, upper)
struct PriorInit <: AbstractInitStrategy end

# `PriorInit` ignores the variable name and draws a value directly from `dist`.
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit)
    return rand(rng, dist)
end
34+
35+ """
36+ UniformInit()
37+ UniformInit(lower, upper)
6138
6239Obtain new values by first transforming the distribution of the random variable
6340to unconstrained space, and then sampling a value uniformly between `lower` and
@@ -70,41 +47,65 @@ default initialisation strategy.
7047
7148[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
7249"""
73- struct Uniform {T<: AbstractFloat } <: AbstractInitStrategy
struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy
    # Bounds of the sampling interval, expressed in unconstrained space.
    lower::T
    upper::T
    # Validating constructor. The check is written as `lower <= upper ||` rather
    # than `lower > upper &&` so that NaN bounds, for which every comparison is
    # false, are rejected as well.
    #
    # Throws `ArgumentError` unless `lower <= upper`.
    function UniformInit(lower::T, upper::T) where {T<:AbstractFloat}
        lower <= upper ||
            throw(ArgumentError("`lower` must be less than or equal to `upper`"))
        return new{T}(lower, upper)
    end
    # Convenience constructor for bounds that are not already two floats of the
    # same type, e.g. `UniformInit(-2, 2)` or `UniformInit(-1.0f0, 1.0)`. The
    # bounds are converted to a common floating-point type and then validated by
    # the constructor above.
    function UniformInit(lower::Real, upper::Real)
        lo, hi = promote(float(lower), float(upper))
        # Guard against `Real` types whose `float` is not an `AbstractFloat`;
        # without it this method would call itself forever.
        lo isa AbstractFloat || throw(
            ArgumentError("`lower` and `upper` must be convertible to floating point")
        )
        return UniformInit(lo, hi)
    end
    # Default interval [-2, 2].
    UniformInit() = UniformInit(-2.0, 2.0)
end
# Initialise a variable by sampling uniformly on `[u.lower, u.upper]` in the
# unconstrained (linked) space of `dist` and mapping the draw back into the
# support of `dist`. The variable name is not used by this strategy.
#
# Returns a value in the original (untransformed) space of `dist`.
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit)
    b = Bijectors.bijector(dist)
    # Size of the unconstrained representation of a draw from `dist`.
    sz = Bijectors.output_size(b, size(dist))
    # Rescale standard-uniform draws by hand instead of going through
    # `Uniform(u.lower, u.upper)`: that constructor rejects `lower == upper`,
    # which the `UniformInit` constructor explicitly allows, and for a
    # univariate `dist` (`sz == ()`) it would yield a zero-dimensional array
    # rather than a scalar.
    y = u.lower .+ (u.upper - u.lower) .* rand(rng, sz...)
    x = Bijectors.inverse(b)(y)
    # In case the inverse transform still hands back a zero-dimensional array,
    # unwrap it so that scalar distributions always receive a scalar value.
    return x isa AbstractArray{<:Any,0} ? x[] : x
end
77- Uniform () = Uniform (- 2 , 2 )
7867
7968"""
80- Params (params::AbstractDict{VarName, Any }, default::AbstractInitStrategy)
81- Params (params::NamedTuple, default::AbstractInitStrategy)
69+ ParamsInit (params::AbstractDict{<: VarName}, default::AbstractInitStrategy=PriorInit() )
70+ ParamsInit (params::NamedTuple, default::AbstractInitStrategy=PriorInit() )
8271
8372Obtain new values by extracting them from the given dictionary or NamedTuple.
84- These values are assumed to be provided in the space of the untransformed
85- distribution.
86-
8773The parameter `default` specifies how new values are to be obtained if they
88- cannot be found in `params`. The default for `default` is `Prior()`.
74+ cannot be found in `params`, or they are specified as `missing`. The default
75+ for `default` is `PriorInit()`.
76+
77+ !!! note
78+ These values must be provided in the space of the untransformed distribution.
8979"""
90- struct Params {P,S<: AbstractInitStrategy } <: AbstractInitStrategy
struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
    # User-supplied values, keyed by `VarName`.
    params::P
    # Strategy consulted for variables that `params` does not provide.
    default::S
    # Canonical constructor: store the dictionary and fallback strategy as given.
    function ParamsInit(
        params::P, default::S
    ) where {P<:AbstractDict{<:VarName},S<:AbstractInitStrategy}
        return new{P,S}(params, default)
    end
    # Dictionary only: fall back to `PriorInit()` for anything not supplied.
    function ParamsInit(params::AbstractDict{<:VarName})
        return ParamsInit(params, PriorInit())
    end
    # NamedTuple input is passed through `to_varname_dict` before being stored.
    function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
        return ParamsInit(to_varname_dict(params), default)
    end
end
# Look `vn` up in the user-supplied parameters. A value that is present and not
# `missing` is returned as-is; otherwise the fallback strategy `p.default`
# generates the value.
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
    if hasvalue(p.params, vn)
        supplied = getvalue(p.params, vn)
        # TODO: Check that the type of the supplied value matches the dist?
        ismissing(supplied) || return supplied
    end
    return init(rng, vn, dist, p.default)
end
103104
104105"""
105106 InitContext(
106107 [rng::Random.AbstractRNG=Random.default_rng()],
107- [strategy::AbstractInitStrategy=Prior ()],
108+ [strategy::AbstractInitStrategy=PriorInit ()],
108109 )
109110
110111A leaf context that indicates that new values for random variables are
@@ -115,95 +116,144 @@ VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then
struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext
    # Random number generator handed to `init` when new values are generated.
    rng::R
    # Strategy that decides how new values are obtained.
    strategy::S
    # Explicit RNG; the strategy defaults to sampling from the prior.
    InitContext(rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit()) =
        new{typeof(rng),typeof(strategy)}(rng, strategy)
    # No RNG given: use `Random.default_rng()`.
    InitContext(strategy::AbstractInitStrategy=PriorInit()) =
        InitContext(Random.default_rng(), strategy)
end
# `InitContext` is declared a leaf of the context stack.
function NodeTrait(::InitContext)
    return IsLeaf()
end
126129
function tilde_assume(
    ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
)
    # Remember whether `vn` already has an entry before anything is modified.
    had_entry = haskey(vi, vn)
    # `init()` always produces a value in the original (possibly constrained)
    # space.
    x = init(ctx.rng, vn, dist, ctx.strategy)
    # `to_maybe_linked_internal_transform` would normally make this choice, but
    # it relies on `istrans(vi, vn)`, which fails when `vn` is not in `vi`. Check
    # by hand instead; a variable that is new to `vi` is stored unlinked.
    linked = had_entry && istrans(vi, vn)
    to_internal = if linked
        to_linked_internal_transform(vi, vn, dist)
    else
        to_internal_transform(vi, vn, dist)
    end
    # TODO(penelopeysm): Ideally this would be
    #     y, logjac = with_logabsdet_jacobian(to_internal, x)
    # but the `to_{linked_}internal_transform` functions always convert `x` to a
    # vector (length 1 for a univariate `dist`), so the Jacobian term is
    # computed separately from the bijector. It would be nice to unify these.
    y = to_internal(x)
    jac_map = linked ? Bijectors.bijector(dist) : identity
    logjac = logabsdetjac(jac_map, x)
    # `push!!` errors when the variable already exists, so existing entries are
    # overwritten with `setindex!!` instead.
    vi = had_entry ? setindex!!(vi, y, vn) : push!!(vi, vn, y, dist)
    # `accumulate_assume!!` expects the untransformed value as its second
    # argument.
    vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
    # Return the untransformed value: it is what the left-hand side of the
    # tilde-statement gets bound to.
    return x, vi
end
141167
142- # TODO : Remove this thing.
143- # function assume(
144- # rng::Random.AbstractRNG,
145- # init_strategy::AbstractInitStrategy,
146- # dist::Distribution,
147- # vn::VarName,
148- # vi::AbstractVarInfo,
168+ # """
169+ # set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
170+ # set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
171+ #
172+ # Take the values inside `initial_params`, replace the corresponding values in
173+ # the given VarInfo object, and return a new VarInfo object with the updated values.
174+ #
175+ # This differs from `DynamicPPL.unflatten` in two ways:
176+ #
177+ # 1. It works with `NamedTuple` arguments.
178+ # 2. For the `AbstractVector` method, if any of the elements are missing, it will not
179+ # overwrite the original value in the VarInfo (it will just use the original
180+ # value instead).
181+ # """
182+ # function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
183+ # throw(
184+ # ArgumentError(
185+ # "`initial_params` must be a vector of type `Union{Real,Missing}`. " *
186+ # "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
187+ # ),
188+ # )
189+ # end
190+ #
191+ # function set_initial_values(
192+ # varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
149193# )
150- # if haskey(vi, vn)
151- # # Always overwrite the parameters with new ones for `SampleFromUniform`.
152- # if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
153- # # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
154- # # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
155- # # if that's okay.
156- # unset_flag!(vi, vn, "del", true)
157- # r = init(rng, dist, sampler)
158- # f = to_maybe_linked_internal_transform(vi, vn, dist)
159- # # TODO (mhauru) This should probably be call a function called setindex_internal!
160- # vi = BangBang.setindex!!(vi, f(r), vn)
161- # setorder!(vi, vn, get_num_produce(vi))
162- # else
163- # # Otherwise we just extract it.
164- # r = vi[vn, dist]
165- # end
166- # else
167- # r = init(rng, dist, sampler)
168- # if istrans(vi)
169- # f = to_linked_internal_transform(vi, vn, dist)
170- # vi = push!!(vi, vn, f(r), dist)
171- # # By default `push!!` sets the transformed flag to `false`.
172- # vi = settrans!!(vi, true, vn)
173- # else
174- # vi = push!!(vi, vn, r, dist)
194+ # flattened_param_vals = varinfo[:]
195+ # length(flattened_param_vals) == length(initial_params) || throw(
196+ # DimensionMismatch(
197+ # "Provided initial value size ($(length(initial_params))) doesn't match " *
198+ # "the model size ($(length(flattened_param_vals))).",
199+ # ),
200+ # )
201+ #
202+ # # Update values that are provided.
203+ # for i in eachindex(initial_params)
204+ # x = initial_params[i]
205+ # if x !== missing
206+ # flattened_param_vals[i] = x
175207# end
176208# end
177209#
178- # # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
179- # logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
180- # vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
181- # return r, vi
210+ # # Update in `varinfo`.
211+ # new_varinfo = unflatten(varinfo, flattened_param_vals)
212+ # return new_varinfo
182213# end
183-
184- # function assume(
185- # rng::Random.AbstractRNG,
186- # sampler::Union{SampleFromPrior,SampleFromUniform},
187- # dist::Distribution,
188- # vn::VarName,
189- # vi::SimpleOrThreadSafeSimple,
190- # )
191- # value = init(rng, dist, sampler)
192- # # Transform if we're working in unconstrained space.
193- # f = to_maybe_linked_internal_transform(vi, vn, dist)
194- # value_raw, logjac = with_logabsdet_jacobian(f, value)
195- # vi = BangBang.push!!(vi, vn, value_raw, dist)
196- # vi = accumulate_assume!!(vi, value, -logjac, vn, dist)
197- # return value, vi
198- # end
199-
200- # Initializations.
201- # init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
202- # function init(rng, dist, ::SampleFromUniform)
203- # return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist)
214+ #
215+ # function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
216+ # varinfo = deepcopy(varinfo)
217+ # vars_in_varinfo = keys(varinfo)
218+ # for v in keys(initial_params)
219+ # vn = VarName{v}()
220+ # if !(vn in vars_in_varinfo)
221+ # for vv in vars_in_varinfo
222+ # if subsumes(vn, vv)
223+ # throw(
224+ # ArgumentError(
225+ # "The current model contains sub-variables of $v, such as ($vv). " *
226+ # "Using NamedTuple for initial_params is not supported in such a case. " *
227+ # "Please use AbstractVector for initial_params instead of NamedTuple.",
228+ # ),
229+ # )
230+ # end
231+ # end
232+ # throw(ArgumentError("Variable $v not found in the model."))
233+ # end
234+ # end
235+ # initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
236+ # return update_values!!(
237+ # varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
238+ # )
204239# end
205240#
206- # init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n)
207- # function init(rng, dist, ::SampleFromUniform, n::Int)
208- # return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
241+ # function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
242+ # @debug "Using passed-in initial variable values" initial_params
243+ #
244+ # # `invlink` the varinfo if needed.
245+ # linked = islinked(vi)
246+ # if linked
247+ # vi = invlink!!(vi, model)
248+ # end
249+ #
250+ # # Set the values in `vi`.
251+ # vi = set_initial_values(vi, initial_params)
252+ #
253+ # # `link` again if needed, restoring the original linked state.
254+ # if linked
255+ # vi = link!!(vi, model)
256+ # end
257+ #
258+ # return vi
209259# end
0 commit comments