11"""
22 AbstractInitStrategy
33
4- Abstract type representing the possible ways of initialising new values for
5- the random variables in a model (e.g., when creating a new VarInfo).
4+ Abstract type representing the possible ways of initialising new values for the random
5+ variables in a model (e.g., when creating a new VarInfo).
66
7- Any subtype of `AbstractInitStrategy` must implement the
8- [`DynamicPPL.init`](@ref) method.
7+ Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method,
8+ and very rarely, [`DynamicPPL.get_param_eltype`](@ref).
99"""
1010abstract type AbstractInitStrategy end
1111
@@ -14,14 +14,58 @@ abstract type AbstractInitStrategy end
1414
1515Generate a new value for a random variable with the given distribution.
1616
17- This function must return a tuple of:
17+ This function must return a tuple `(x, trf)`, where
1818
19- - the generated value
20- - a function that transforms the generated value back to the unlinked space. If the value is
21- already in unlinked space, then this should be `identity`.
19+ - `x` is the generated value
20+
21+ - `trf` is a function that transforms the generated value back to the unlinked space. If the
22+ value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You
23+ can also use `Base.identity`, but if you use this, you **must** be confident that
24+ `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more
25+ information.
2226"""
2327function init end
2428
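The caveat about `Base.identity` in the docstring above depends on whether `zero(eltype(x))` is defined for the values a strategy generates. The following is a minimal sketch in plain Julia, with no DynamicPPL involved, showing which kinds of `x` satisfy that condition. The helper name `zero_eltype_works` is hypothetical and exists only for illustration.

```julia
# Hypothetical helper (not part of DynamicPPL): returns true if
# `zero(eltype(x))` can be computed without throwing.
function zero_eltype_works(x)
    try
        zero(eltype(x))
        return true
    catch
        return false
    end
end

scalar_ok = zero_eltype_works(1.0)             # eltype is Float64
vector_ok = zero_eltype_works([1.0, 2.0])      # eltype is Float64
any_ok = zero_eltype_works(Any[1.0, 2.0])      # eltype is Any: no `zero(Any)` method
nested_ok = zero_eltype_works([[1.0], [2.0]])  # eltype is Vector{Float64}: no `zero` for that type
```

Under these assumptions, the first two cases succeed and the last two fail, which is the kind of situation where the docstring recommends `DynamicPPL.typed_identity` over `Base.identity`.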
29+ """
30+ DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy)
31+
32+ Return the element type of the parameters generated from the given initialisation strategy.
33+
34+ The default implementation returns `Any`. However, for `InitFromParams` which provides known
35+ parameters for evaluating the model, methods are implemented in order to return more specific
36+ types.
37+
38+ For the most part, a return value of `Any` will actually suffice. However, there are a few
39+ edge cases in DynamicPPL where the element type is needed. These largely relate to
40+ determining the element type of accumulators ahead of time (_before_ evaluation), as well as
41+ promoting type parameters in model arguments. The classic case is when evaluating a model
42+ with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}`
43+ arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in
44+ SparseConnectivityTracer.jl, also require similar treatment.
45+
46+ If `AbstractInitStrategy` is never used in combination with tracer types, then it is
47+ perfectly safe to return `Any`. This does not lead to type instability downstream because
48+ the actual accumulators will still be created with concrete Float types (the `Any` is just
49+ used to determine whether the float type needs to be modified).
50+
51+ (Detail: in fact, the above is not always true. Firstly, the accumulator argument is only
52+ true when evaluating with ThreadSafeVarInfo. See the comments in `DynamicPPL.unflatten` for
53+ more details. For non-threadsafe evaluation, Julia is capable of automatically promoting the
54+ types on its own. Secondly, the promotion only matters if you are trying to directly assign
55+ into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar tracer type, for example using
56+ `xs[i] = MyDual`. This doesn't actually apply to tilde-statements like `xs[i] ~ ...` because
57+ those use `Accessors.@set` under the hood, which also does the promotion for you.)
58+ """
59+ get_param_eltype(::AbstractInitStrategy) = Any
60+ function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges})
61+     return eltype(strategy.params.vect)
62+ end
63+ function get_param_eltype(
64+     strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}}
65+ )
66+     return infer_nested_eltype(typeof(strategy.params))
67+ end
68+
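The promotion problem described in the `get_param_eltype` docstring can be reproduced without any AD package. The sketch below uses a hypothetical `ToyDual` type as a stand-in for a tracer number such as `ForwardDiff.Dual`, and a hypothetical `buffer_eltype` helper to show how a parameter element type of `Any` can mean "keep the default float type" while a more specific type triggers promotion. Neither name exists in DynamicPPL, and this is not how DynamicPPL implements the promotion internally.

```julia
# Hypothetical stand-in for a tracer number such as ForwardDiff.Dual.
struct ToyDual <: Real
    value::Float64
    partial::Float64
end

# Hypothetical helper: pick a buffer element type from the parameter eltype.
# `Any` carries no information, so the default float type is kept.
buffer_eltype(::Type{Any}) = Float64
buffer_eltype(::Type{T}) where {T<:Real} = T

params = [ToyDual(1.0, 1.0), ToyDual(2.0, 0.0)]

# Direct assignment into a Vector{Float64} fails, because there is no
# conversion from ToyDual to Float64.
xs = zeros(Float64, 2)
assignment_failed = try
    xs[1] = params[1]
    false
catch err
    err isa MethodError
end

# Allocating the buffer with the promoted element type avoids the problem.
ys = Vector{buffer_eltype(eltype(params))}(undef, 2)
ys[1] = params[1]
```

This mirrors the `xs[i] = MyDual` case mentioned in the docstring: the element type has to be known before the buffer is created, which is why it is queried ahead of evaluation.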
2569"""
2670 InitFromPrior()
2771
74118
75119"""
76120 InitFromParams(
77- params::Union{AbstractDict{<:VarName},NamedTuple},
121+ params::Any
78122 fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
79123 )
80124
81- Obtain new values by extracting them from the given dictionary or NamedTuple.
125+ Obtain new values by extracting them from the given set of `params`.
126+
127+ The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which
128+ provides a mapping from variable names to values. However, we leave the type of `params`
129+ open in order to allow for custom parameter storage types.
130+
131+ ## Custom parameter storage types
82132
83- The parameter `fallback` specifies how new values are to be obtained if they
84- cannot be found in `params`, or they are specified as `missing`. `fallback`
85- can either be an initialisation strategy itself, in which case it will be
86- used to obtain new values, or it can be `nothing`, in which case an error
87- will be thrown. The default for `fallback` is `InitFromPrior()`.
133+ For `InitFromParams` to work correctly with a custom `params::P`, you need to implement
88134
89- !!! note
90- The values in `params` must be provided in the space of the untransformed
91- distribution.
135+ ```julia
136+ DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P}
137+ ```
138+
139+ This tells you how to obtain values for the random variable `vn` from `p.params`. Note that
140+ the last argument is `InitFromParams(params)`, not just `params` itself. Please see the
141+ docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour.
142+
143+ If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case,
144+ then you will not need to implement anything else. So far, this is the same as you would do
145+ for creating any new `AbstractInitStrategy` subtype.
146+
147+ However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to
148+ implement
149+
150+ ```julia
151+ DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P}
152+ ```
153+
154+ See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this
155+ is needed.
156+
157+ The argument `fallback` specifies how new values are to be obtained if they cannot be found
158+ in `params`, or they are specified as `missing`. `fallback` can either be an initialisation
159+ strategy itself, in which case it will be used to obtain new values, or it can be `nothing`,
160+ in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`.
92161"""
93162struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
94163     params::P
@@ -102,11 +171,8 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
102171end
103172
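The "custom parameter storage types" section of the docstring relies on dispatching on the type parameter `P` of the strategy wrapper. The following self-contained sketch imitates that pattern with toy types so that it runs without DynamicPPL. All names here (`ToyStrategy`, `ToyFromParams`, `ConstantParams`, `toy_init`, `toy_param_eltype`) are hypothetical analogues of `AbstractInitStrategy`, `InitFromParams`, a user-defined storage type, `init`, and `get_param_eltype`; the real methods take a `VarName` and a `Distribution` as shown in the docstring.

```julia
using Random

# Hypothetical stand-ins; none of these names exist in DynamicPPL.
abstract type ToyStrategy end

struct ToyFromParams{P} <: ToyStrategy
    params::P
end

# Default: no element-type information is available.
toy_param_eltype(::ToyStrategy) = Any

# Custom storage type: every variable receives the same value.
struct ConstantParams{T<:Real}
    value::T
end

# Analogue of `init`: dispatch on the wrapper, not on the storage type itself.
# Returns the value and a transform back to unlinked space. Plain `identity`
# is used here because the value is a simple real number.
function toy_init(::Random.AbstractRNG, name::Symbol, p::ToyFromParams{<:ConstantParams})
    return (p.params.value, identity)
end

# Analogue of `get_param_eltype`: expose the concrete element type.
toy_param_eltype(::ToyFromParams{ConstantParams{T}}) where {T} = T

strategy = ToyFromParams(ConstantParams(0.5))
x, trf = toy_init(Random.default_rng(), :x, strategy)
```

The design point carried over from the docstring is that both methods are defined for the wrapped strategy type (`ToyFromParams{P}` here, `InitFromParams{P}` in DynamicPPL), so the storage type `P` itself needs no knowledge of the initialisation machinery.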
104173function init(
105-     rng::Random.AbstractRNG,
106-     vn::VarName,
107-     dist::Distribution,
108-     p::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}},
109- )
174+     rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P}
175+ ) where {P<:Union{AbstractDict{<:VarName},NamedTuple}}
110176 # TODO (penelopeysm): It would be nice to do a check to make sure that all
111177 # of the parameters in `p.params` were actually used, and either warn or
112178 # error if they aren't. This is actually quite non-trivial though because