@@ -9,16 +9,24 @@ Any subtype of `AbstractInitStrategy` must implement the
99"""
1010abstract type AbstractInitStrategy end
1111
12+ """
13+ InitValue{T,F}
14+
15+ A wrapper type representing a value of type `T`. The function `F` indicates what transform
16+ to apply to the value to convert it back to the unlinked space. If `value` is already in
17+ unlinked space, then `transform` can be `identity`.
18+ """
19+ struct InitValue{T,F}
20+ value:: T
21+ transform:: F
22+ end
23+
1224"""
1325 init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
1426
1527Generate a new value for a random variable with the given distribution.
1628
17- !!! warning "Return values must be unlinked"
18- The values returned by `init` must always be in the untransformed space, i.e.,
19- they must be within the support of the original distribution. That means that,
20- for example, `init(rng, dist, u::InitFromUniform)` will in general return values that
21- are outside the range [u.lower, u.upper].
29+ This function must return a `InitValue`.
2230"""
2331function init end
2432
@@ -29,7 +37,7 @@ Obtain new values by sampling from the prior distribution.
2937"""
3038struct InitFromPrior <: AbstractInitStrategy end
3139function init (rng:: Random.AbstractRNG , :: VarName , dist:: Distribution , :: InitFromPrior )
32- return rand (rng, dist)
40+ return InitValue ( rand (rng, dist), identity )
3341end
3442
3543"""
@@ -69,7 +77,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
6977 if x isa Array{<: Any ,0 }
7078 x = x[]
7179 end
72- return x
80+ return InitValue (x, identity)
7381end
7482
7583"""
@@ -93,19 +101,20 @@ will be thrown. The default for `fallback` is `InitFromPrior()`.
93101struct InitFromParams{P,S<: Union{AbstractInitStrategy,Nothing} } <: AbstractInitStrategy
94102 params:: P
95103 fallback:: S
104+
96105 function InitFromParams (
97- params:: AbstractDict{<:VarName} ,
98- fallback:: Union{AbstractInitStrategy,Nothing} = InitFromPrior (),
99- )
100- return new {typeof(params),typeof(fallback)} (params, fallback)
101- end
102- function InitFromParams (
103- params:: NamedTuple , fallback:: Union{AbstractInitStrategy,Nothing} = InitFromPrior ()
104- )
105- return new {typeof(params),typeof(fallback)} (params, fallback)
106+ params:: P , fallback:: Union{AbstractInitStrategy,Nothing} = InitFromPrior ()
107+ ) where {P}
108+ return new {P,typeof(fallback)} (params, fallback)
106109 end
107110end
108- function init (rng:: Random.AbstractRNG , vn:: VarName , dist:: Distribution , p:: InitFromParams )
111+
112+ function init (
113+ rng:: Random.AbstractRNG ,
114+ vn:: VarName ,
115+ dist:: Distribution ,
116+ p:: InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} ,
117+ )
109118 # TODO (penelopeysm): It would be nice to do a check to make sure that all
110119 # of the parameters in `p.params` were actually used, and either warn or
111120 # error if they aren't. This is actually quite non-trivial though because
@@ -119,14 +128,82 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF
119128 else
120129 # TODO (penelopeysm): Since x is user-supplied, maybe we could also
121130 # check here that the type / size of x matches the dist?
122- x
131+ InitValue (x, identity)
123132 end
124133 else
125134 p. fallback === nothing && error (" No value was provided for the variable `$(vn) `." )
126135 init (rng, vn, dist, p. fallback)
127136 end
128137end
129138
139+ """
140+ RangeAndLinked
141+
142+ Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable
143+ in the model will in general correspond to a sub-vector of `params`. This struct stores
144+ information about that range, as well as whether the sub-vector represents a linked value or
145+ an unlinked value.
146+
147+ $(TYPEDFIELDS)
148+ """
149+ struct RangeAndLinked
150+ # indices that the variable corresponds to in the vectorised parameter
151+ range:: UnitRange{Int}
152+ # whether it's linked
153+ is_linked:: Bool
154+ end
155+
156+ """
157+ VectorWithRanges(
158+ iden_varname_ranges::NamedTuple,
159+ varname_ranges::Dict{VarName,RangeAndLinked},
160+ vect::AbstractVector{<:Real},
161+ )
162+
163+ A struct that wraps a vector of parameter values, plus information about how random
164+ variables map to ranges in that vector.
165+
166+ In the simplest case, this could be accomplished only with a single dictionary mapping
167+ VarNames to ranges and link status. However, for performance reasons, we separate out
168+ VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
169+ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
170+
171+ It would be nice to improve the NamedTuple and Dict approach. See, e.g.
172+ https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
173+ """
174+ struct VectorWithRanges{N<: NamedTuple ,T<: AbstractVector{<:Real} }
175+ # This NamedTuple stores the ranges for identity VarNames
176+ iden_varname_ranges:: N
177+ # This Dict stores the ranges for all other VarNames
178+ varname_ranges:: Dict{VarName,RangeAndLinked}
179+ # The full parameter vector which we index into to get variable values
180+ vect:: T
181+ end
182+
183+ function _get_range_and_linked (
184+ vr:: VectorWithRanges , :: VarName{sym,typeof(identity)}
185+ ) where {sym}
186+ return vr. iden_varname_ranges[sym]
187+ end
188+ function _get_range_and_linked (vr:: VectorWithRanges , vn:: VarName )
189+ return vr. varname_ranges[vn]
190+ end
191+ function init (
192+ :: Random.AbstractRNG ,
193+ vn:: VarName ,
194+ dist:: Distribution ,
195+ p:: InitFromParams{<:VectorWithRanges} ,
196+ )
197+ vr = p. params
198+ range_and_linked = _get_range_and_linked (vr, vn)
199+ transform = if range_and_linked. is_linked
200+ from_linked_vec_transform (dist)
201+ else
202+ from_vec_transform (dist)
203+ end
204+ return InitValue ((@view vr. vect[range_and_linked. range]), transform)
205+ end
206+
130207"""
131208 InitContext(
132209 [rng::Random.AbstractRNG=Random.default_rng()],
@@ -156,27 +233,43 @@ function tilde_assume!!(
156233 ctx:: InitContext , dist:: Distribution , vn:: VarName , vi:: AbstractVarInfo
157234)
158235 in_varinfo = haskey (vi, vn)
159- # `init()` always returns values in original space, i.e. possibly
160- # constrained
161- x = init (ctx. rng, vn, dist, ctx. strategy)
236+ init_val = init (ctx. rng, vn, dist, ctx. strategy)
237+ x, inv_logjac = with_logabsdet_jacobian (init_val. transform, init_val. value)
162238 # Determine whether to insert a transformed value into the VarInfo.
163239 # If the VarInfo alrady had a value for this variable, we will
164240 # keep the same linked status as in the original VarInfo. If not, we
165241 # check the rest of the VarInfo to see if other variables are linked.
166242 # is_transformed(vi) returns true if vi is nonempty and all variables in vi
167243 # are linked.
168244 insert_transformed_value = in_varinfo ? is_transformed (vi, vn) : is_transformed (vi)
169- y, logjac = if insert_transformed_value
170- with_logabsdet_jacobian (link_transform (dist), x)
245+ val_to_insert, logjac = if insert_transformed_value
246+ # Calculate the forward logjac and sum them up.
247+ y, fwd_logjac = with_logabsdet_jacobian (link_transform (dist), x)
248+ # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian
249+ # calculation wastes a lot of time going from linked vectorised -> unlinked ->
250+ # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. However,
251+ # `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which case this
252+ # branch is never hit (since `in_varinfo` will always be false). So we can leave
253+ # this branch in for full generality with other combinations of init strategies /
254+ # VarInfo.
255+ #
256+ # TODO (penelopeysm): Figure out one day how to refactor this. The crux of the issue
257+ # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`,
258+ # which is NOT the same as `inverse(link_transform)` (because there is an additional
259+ # vectorisation step). We need `init` and `tilde_assume!!` to share this information
260+ # but it's not clear right now how to do this. In my opinion, the most productive
261+ # way forward would be to standardise the behaviour of bijectors so that we can have
262+ # a clean separation between the linking and vectorisation parts of it.
263+ y, inv_logjac + fwd_logjac
171264 else
172- x, zero (LogProbType)
265+ x, inv_logjac
173266 end
174267 # Add the new value to the VarInfo. `push!!` errors if the value already
175268 # exists, hence the need for setindex!!.
176269 if in_varinfo
177- vi = setindex!! (vi, y , vn)
270+ vi = setindex!! (vi, val_to_insert , vn)
178271 else
179- vi = push!! (vi, vn, y , dist)
272+ vi = push!! (vi, vn, val_to_insert , dist)
180273 end
181274 # Neither of these set the `trans` flag so we have to do it manually if
182275 # necessary.
0 commit comments