Skip to content

Commit 1156a49

Browse files
committed
Tidy up loads of things
1 parent de88c78 commit 1156a49

File tree

4 files changed

+103
-38
lines changed

4 files changed

+103
-38
lines changed

docs/src/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ DynamicPPL.prefix
170170

171171
## Utilities
172172

173+
`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors.
174+
175+
```@docs
176+
typed_identity
177+
```
178+
173179
It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
174180

175181
```@docs
@@ -517,10 +523,12 @@ InitFromParams
517523
```
518524

519525
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.
526+
In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy.
520527

521528
```@docs
522529
AbstractInitStrategy
523530
init
531+
get_param_eltype
524532
```
525533

526534
### Choosing a suitable VarInfo

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ export AbstractVarInfo,
8484
# Compiler
8585
@model,
8686
# Utilities
87-
init,
8887
OrderedDict,
88+
typed_identity,
8989
# Model
9090
Model,
9191
getmissings,

src/contexts/init.jl

Lines changed: 89 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""
22
AbstractInitStrategy
33
4-
Abstract type representing the possible ways of initialising new values for
5-
the random variables in a model (e.g., when creating a new VarInfo).
4+
Abstract type representing the possible ways of initialising new values for the random
5+
variables in a model (e.g., when creating a new VarInfo).
66
7-
Any subtype of `AbstractInitStrategy` must implement the
8-
[`DynamicPPL.init`](@ref) method.
7+
Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method,
8+
and very rarely, [`DynamicPPL.get_param_eltype`](@ref).
99
"""
1010
abstract type AbstractInitStrategy end
1111

@@ -14,14 +14,58 @@ abstract type AbstractInitStrategy end
1414
1515
Generate a new value for a random variable with the given distribution.
1616
17-
This function must return a tuple of:
17+
This function must return a tuple `(x, trf)`, where
1818
19-
- the generated value
20-
- a function that transforms the generated value back to the unlinked space. If the value is
21-
already in unlinked space, then this should be `identity`.
19+
- `x` is the generated value
20+
21+
- `trf` is a function that transforms the generated value back to the unlinked space. If the
22+
value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You
23+
can also use `Base.identity`, but if you use this, you **must** be confident that
24+
`zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more
25+
information.
2226
"""
2327
function init end
2428

29+
"""
30+
DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy)
31+
32+
Return the element type of the parameters generated from the given initialisation strategy.
33+
34+
The default implementation returns `Any`. However, for `InitFromParams` which provides known
35+
parameters for evaluating the model, methods are implemented in order to return more specific
36+
types.
37+
38+
For the most part, a return value of `Any` will actually suffice. However, there are a few
39+
edge cases in DynamicPPL where the element type is needed. These largely relate to
40+
determining the element type of accumulators ahead of time (_before_ evaluation), as well as
41+
promoting type parameters in model arguments. The classic case is when evaluating a model
42+
with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}`
43+
arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in
44+
SparseConnectivityTracer.jl, also require similar treatment.
45+
46+
If `AbstractInitStrategy` is never used in combination with tracer types, then it is
47+
perfectly safe to return `Any`. This does not lead to type instability downstream because
48+
the actual accumulators will still be created with concrete Float types (the `Any` is just
49+
used to determine whether the float type needs to be modified).
50+
51+
(Detail: in fact, the above is not always true. Firstly, the accumulator argument is only
52+
true when evaluating with ThreadSafeVarInfo. See the comments in `DynamicPPL.unflatten` for
53+
more details. For non-threadsafe evaluation, Julia is capable of automatically promoting the
54+
types on its own. Secondly, the promotion only matters if you are trying to directly assign
55+
into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar tracer type, for example using
56+
`xs[i] = MyDual`. This doesn't actually apply to tilde-statements like `xs[i] ~ ...` because
57+
those use `Accessors.@set` under the hood, which also does the promotion for you.)
58+
"""
59+
get_param_eltype(::AbstractInitStrategy) = Any
60+
function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges})
61+
return eltype(strategy.params.vect)
62+
end
63+
function get_param_eltype(
64+
strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}}
65+
)
66+
return infer_nested_eltype(typeof(strategy.params))
67+
end
68+
2569
"""
2670
InitFromPrior()
2771
@@ -74,21 +118,46 @@ end
74118

75119
"""
76120
InitFromParams(
77-
params::Union{AbstractDict{<:VarName},NamedTuple},
121+
params::Any
78122
fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
79123
)
80124
81-
Obtain new values by extracting them from the given dictionary or NamedTuple.
125+
Obtain new values by extracting them from the given set of `params`.
126+
127+
The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which
128+
provides a mapping from variable names to values. However, we leave the type of `params`
129+
open in order to allow for custom parameter storage types.
130+
131+
## Custom parameter storage types
82132
83-
The parameter `fallback` specifies how new values are to be obtained if they
84-
cannot be found in `params`, or they are specified as `missing`. `fallback`
85-
can either be an initialisation strategy itself, in which case it will be
86-
used to obtain new values, or it can be `nothing`, in which case an error
87-
will be thrown. The default for `fallback` is `InitFromPrior()`.
133+
For `InitFromParams` to work correctly with a custom `params::P`, you need to implement
88134
89-
!!! note
90-
The values in `params` must be provided in the space of the untransformed
91-
distribution.
135+
```julia
136+
DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P}
137+
```
138+
139+
This tells you how to obtain values for the random variable `vn` from `p.params`. Note that
140+
the last argument is `InitFromParams(params)`, not just `params` itself. Please see the
141+
docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour.
142+
143+
If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case,
144+
then you will not need to implement anything else. So far, this is the same as you would do
145+
for creating any new `AbstractInitStrategy` subtype.
146+
147+
However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to
148+
implement
149+
150+
```julia
151+
DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P}
152+
```
153+
154+
See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this
155+
is needed.
156+
157+
The argument `fallback` specifies how new values are to be obtained if they cannot be found
158+
in `params`, or they are specified as `missing`. `fallback` can either be an initialisation
159+
strategy itself, in which case it will be used to obtain new values, or it can be `nothing`,
160+
in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`.
92161
"""
93162
struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
94163
params::P
@@ -102,11 +171,8 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
102171
end
103172

104173
function init(
105-
rng::Random.AbstractRNG,
106-
vn::VarName,
107-
dist::Distribution,
108-
p::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}},
109-
)
174+
rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P}
175+
) where {P<:Union{AbstractDict{<:VarName},NamedTuple}}
110176
# TODO(penelopeysm): It would be nice to do a check to make sure that all
111177
# of the parameters in `p.params` were actually used, and either warn or
112178
# error if they aren't. This is actually quite non-trivial though because

src/model.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,28 +1021,19 @@ By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on
10211021
that typically, before evaluation, the parameters will have been inserted into the VarInfo's
10221022
metadata field.
10231023
1024-
For InitContext, it's quite different: because InitContext is responsible for supplying the
1025-
parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside it.
1024+
For `InitContext`, it's quite different: because `InitContext` is responsible for supplying
1025+
the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside
1026+
it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more
1027+
explanation.
10261028
"""
10271029
function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext)
10281030
return get_param_eltype(vi, DynamicPPL.childcontext(ctx))
10291031
end
10301032
get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi)
10311033
function get_param_eltype(::AbstractVarInfo, ctx::InitContext)
1032-
return _get_strat_param_eltype(ctx.strategy)
1034+
return get_param_eltype(ctx.strategy)
10331035
end
10341036

1035-
function _get_strat_param_eltype(strategy::InitFromParams{<:VectorWithRanges})
1036-
return eltype(strategy.params.vect)
1037-
end
1038-
function _get_strat_param_eltype(
1039-
strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}}
1040-
)
1041-
return infer_nested_eltype(typeof(strategy.params))
1042-
end
1043-
# No need to specify a type since new ones are generated
1044-
_get_strat_param_eltype(::Union{InitFromPrior,InitFromUniform}) = Any
1045-
10461037
"""
10471038
getargnames(model::Model)
10481039

0 commit comments

Comments
 (0)