Skip to content

Commit a4c71e6

Browse files
committed
Refactor FastLDF to use InitContext
1 parent ecca1af commit a4c71e6

File tree

7 files changed

+214
-199
lines changed

7 files changed

+214
-199
lines changed

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using EnzymeCore
99
nothing
1010
# Likewise for get_range_and_linked.
1111
@inline EnzymeCore.EnzymeRules.inactive(
12-
::typeof(DynamicPPL.Experimental.get_range_and_linked), args...
12+
::typeof(DynamicPPL._get_range_and_linked), args...
1313
) = nothing
1414

1515
end

ext/DynamicPPLMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Mooncake: Mooncake
66
# This is purely an optimisation.
77
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
88
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{
9-
typeof(DynamicPPL.Experimental.get_range_and_linked),Vararg
9+
typeof(DynamicPPL._get_range_and_linked),Vararg
1010
}
1111

1212
end # module

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ include("abstract_varinfo.jl")
188188
include("threadsafe.jl")
189189
include("varinfo.jl")
190190
include("simple_varinfo.jl")
191+
include("onlyaccs.jl")
191192
include("compiler.jl")
192193
include("pointwise_logdensities.jl")
193194
include("logdensityfunction.jl")

src/contexts/init.jl

Lines changed: 119 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,24 @@ Any subtype of `AbstractInitStrategy` must implement the
99
"""
1010
abstract type AbstractInitStrategy end
1111

12+
"""
13+
InitValue{T,F}
14+
15+
A wrapper type holding a `value` of type `T`. The `transform` field (of type `F`) is the function
16+
to apply to the value to convert it back to the unlinked space. If `value` is already in
17+
unlinked space, then `transform` can be `identity`.
18+
"""
19+
struct InitValue{T,F}
20+
value::T
21+
transform::F
22+
end
23+
1224
"""
1325
init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)
1426
1527
Generate a new value for a random variable with the given distribution.
1628
17-
!!! warning "Return values must be unlinked"
18-
The values returned by `init` must always be in the untransformed space, i.e.,
19-
they must be within the support of the original distribution. That means that,
20-
for example, `init(rng, dist, u::InitFromUniform)` will in general return values that
21-
are outside the range [u.lower, u.upper].
29+
This function must return an `InitValue`.
2230
"""
2331
function init end
2432

@@ -29,7 +37,7 @@ Obtain new values by sampling from the prior distribution.
2937
"""
3038
struct InitFromPrior <: AbstractInitStrategy end
3139
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior)
32-
return rand(rng, dist)
40+
return InitValue(rand(rng, dist), identity)
3341
end
3442

3543
"""
@@ -69,7 +77,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
6977
if x isa Array{<:Any,0}
7078
x = x[]
7179
end
72-
return x
80+
return InitValue(x, identity)
7381
end
7482

7583
"""
@@ -93,19 +101,20 @@ will be thrown. The default for `fallback` is `InitFromPrior()`.
93101
struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
94102
params::P
95103
fallback::S
104+
96105
function InitFromParams(
97-
params::AbstractDict{<:VarName},
98-
fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(),
99-
)
100-
return new{typeof(params),typeof(fallback)}(params, fallback)
101-
end
102-
function InitFromParams(
103-
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
104-
)
105-
return new{typeof(params),typeof(fallback)}(params, fallback)
106+
params::P, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
107+
) where {P}
108+
return new{P,typeof(fallback)}(params, fallback)
106109
end
107110
end
108-
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
111+
112+
function init(
113+
rng::Random.AbstractRNG,
114+
vn::VarName,
115+
dist::Distribution,
116+
p::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}},
117+
)
109118
# TODO(penelopeysm): It would be nice to do a check to make sure that all
110119
# of the parameters in `p.params` were actually used, and either warn or
111120
# error if they aren't. This is actually quite non-trivial though because
@@ -119,14 +128,82 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF
119128
else
120129
# TODO(penelopeysm): Since x is user-supplied, maybe we could also
121130
# check here that the type / size of x matches the dist?
122-
x
131+
InitValue(x, identity)
123132
end
124133
else
125134
p.fallback === nothing && error("No value was provided for the variable `$(vn)`.")
126135
init(rng, vn, dist, p.fallback)
127136
end
128137
end
129138

139+
"""
140+
RangeAndLinked
141+
142+
Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable
143+
in the model will in general correspond to a sub-vector of `params`. This struct stores
144+
information about that range, as well as whether the sub-vector represents a linked value or
145+
an unlinked value.
146+
147+
$(TYPEDFIELDS)
148+
"""
149+
struct RangeAndLinked
150+
# indices that the variable corresponds to in the vectorised parameter
151+
range::UnitRange{Int}
152+
# whether it's linked
153+
is_linked::Bool
154+
end
155+
156+
"""
157+
VectorWithRanges(
158+
iden_varname_ranges::NamedTuple,
159+
varname_ranges::Dict{VarName,RangeAndLinked},
160+
vect::AbstractVector{<:Real},
161+
)
162+
163+
A struct that wraps a vector of parameter values, plus information about how random
164+
variables map to ranges in that vector.
165+
166+
In the simplest case, this could be accomplished only with a single dictionary mapping
167+
VarNames to ranges and link status. However, for performance reasons, we separate out
168+
VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
169+
non-identity-optic VarNames are stored in the `varname_ranges` Dict.
170+
171+
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
172+
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
173+
"""
174+
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
175+
# This NamedTuple stores the ranges for identity VarNames
176+
iden_varname_ranges::N
177+
# This Dict stores the ranges for all other VarNames
178+
varname_ranges::Dict{VarName,RangeAndLinked}
179+
# The full parameter vector which we index into to get variable values
180+
vect::T
181+
end
182+
183+
function _get_range_and_linked(
184+
vr::VectorWithRanges, ::VarName{sym,typeof(identity)}
185+
) where {sym}
186+
return vr.iden_varname_ranges[sym]
187+
end
188+
function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
189+
return vr.varname_ranges[vn]
190+
end
191+
function init(
192+
::Random.AbstractRNG,
193+
vn::VarName,
194+
dist::Distribution,
195+
p::InitFromParams{<:VectorWithRanges},
196+
)
197+
vr = p.params
198+
range_and_linked = _get_range_and_linked(vr, vn)
199+
transform = if range_and_linked.is_linked
200+
from_linked_vec_transform(dist)
201+
else
202+
from_vec_transform(dist)
203+
end
204+
return InitValue((@view vr.vect[range_and_linked.range]), transform)
205+
end
206+
130207
"""
131208
InitContext(
132209
[rng::Random.AbstractRNG=Random.default_rng()],
@@ -156,27 +233,43 @@ function tilde_assume!!(
156233
ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
157234
)
158235
in_varinfo = haskey(vi, vn)
159-
# `init()` always returns values in original space, i.e. possibly
160-
# constrained
161-
x = init(ctx.rng, vn, dist, ctx.strategy)
236+
init_val = init(ctx.rng, vn, dist, ctx.strategy)
237+
x, inv_logjac = with_logabsdet_jacobian(init_val.transform, init_val.value)
162238
# Determine whether to insert a transformed value into the VarInfo.
163239
# If the VarInfo already had a value for this variable, we will
164240
# keep the same linked status as in the original VarInfo. If not, we
165241
# check the rest of the VarInfo to see if other variables are linked.
166242
# is_transformed(vi) returns true if vi is nonempty and all variables in vi
167243
# are linked.
168244
insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
169-
y, logjac = if insert_transformed_value
170-
with_logabsdet_jacobian(link_transform(dist), x)
245+
val_to_insert, logjac = if insert_transformed_value
246+
# Calculate the forward logjac and add it to the inverse logjac.
247+
y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x)
248+
# Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian
249+
# calculation wastes a lot of time going from linked vectorised -> unlinked ->
250+
# linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. However,
251+
# `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which case this
252+
# branch is never hit (since `in_varinfo` will always be false). So we can leave
253+
# this branch in for full generality with other combinations of init strategies /
254+
# VarInfo.
255+
#
256+
# TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue
257+
# is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`,
258+
# which is NOT the same as `inverse(link_transform)` (because there is an additional
259+
# vectorisation step). We need `init` and `tilde_assume!!` to share this information
260+
# but it's not clear right now how to do this. In my opinion, the most productive
261+
# way forward would be to standardise the behaviour of bijectors so that we can have
262+
# a clean separation between the linking and vectorisation parts of it.
263+
y, inv_logjac + fwd_logjac
171264
else
172-
x, zero(LogProbType)
265+
x, inv_logjac
173266
end
174267
# Add the new value to the VarInfo. `push!!` errors if the value already
175268
# exists, hence the need for setindex!!.
176269
if in_varinfo
177-
vi = setindex!!(vi, y, vn)
270+
vi = setindex!!(vi, val_to_insert, vn)
178271
else
179-
vi = push!!(vi, vn, y, dist)
272+
vi = push!!(vi, vn, val_to_insert, dist)
180273
end
181274
# Neither of these sets the `trans` flag, so we have to do it manually if
182275
# necessary.

0 commit comments

Comments
 (0)