@@ -12,9 +12,6 @@ abstract type AbstractInitStrategy end
1212"""
1313 InitValue{T,F}
1414
15- A wrapper type representing a value of type `T`. The function `F` indicates what transform
16- to apply to the value to convert it back to the unlinked space. If `value` is already in
17- unlinked space, then `transform` can be `identity`.
1815"""
1916struct InitValue{T,F}
2017 value:: T
2623
2724Generate a new value for a random variable with the given distribution.
2825
29- This function must return a `InitValue`.
26+ This function must return a tuple of:
27+
28+ - the generated value
29+ - a function that transforms the generated value back to the unlinked space. If the value is
30+ already in unlinked space, then this should be `identity`.
3031"""
3132function init end
3233
@@ -37,7 +38,7 @@ Obtain new values by sampling from the prior distribution.
3738"""
3839struct InitFromPrior <: AbstractInitStrategy end
3940function init (rng:: Random.AbstractRNG , :: VarName , dist:: Distribution , :: InitFromPrior )
40- return InitValue ( rand (rng, dist), identity)
41+ return rand (rng, dist), identity
4142end
4243
4344"""
@@ -77,7 +78,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
7778 if x isa Array{<: Any ,0 }
7879 x = x[]
7980 end
80- return InitValue ( x, identity)
81+ return x, identity
8182end
8283
8384"""
@@ -128,7 +129,7 @@ function init(
128129 else
129130 # TODO (penelopeysm): Since x is user-supplied, maybe we could also
130131 # check here that the type / size of x matches the dist?
131- InitValue ( x, identity)
132+ x, identity
132133 end
133134 else
134135 p. fallback === nothing && error (" No value was provided for the variable `$(vn) `." )
@@ -201,7 +202,7 @@ function init(
201202 else
202203 from_vec_transform (dist)
203204 end
204- return InitValue (( @view vr. vect[range_and_linked. range]), transform)
205+ return ( @view vr. vect[range_and_linked. range]), transform
205206end
206207
207208"""
@@ -233,8 +234,8 @@ function tilde_assume!!(
233234 ctx:: InitContext , dist:: Distribution , vn:: VarName , vi:: AbstractVarInfo
234235)
235236 in_varinfo = haskey (vi, vn)
236- init_val = init (ctx. rng, vn, dist, ctx. strategy)
237- x, inv_logjac = with_logabsdet_jacobian (init_val . transform, init_val . value )
237+ val, transform = init (ctx. rng, vn, dist, ctx. strategy)
238+ x, inv_logjac = with_logabsdet_jacobian (transform, val )
238239 # Determine whether to insert a transformed value into the VarInfo.
239240 # If the VarInfo alrady had a value for this variable, we will
240241 # keep the same linked status as in the original VarInfo. If not, we
0 commit comments