Skip to content

Commit 8eb1ef7

Browse files
committed
fix definition order
1 parent 4a15bb7 commit 8eb1ef7

File tree

1 file changed

+68
-68
lines changed

1 file changed

+68
-68
lines changed

src/contexts/init.jl

Lines changed: 68 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,71 @@
1+
"""
2+
RangeAndLinked
3+
4+
Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable
5+
in the model will in general correspond to a sub-vector of `params`. This struct stores
6+
information about that range, as well as whether the sub-vector represents a linked value or
7+
an unlinked value.
8+
9+
$(TYPEDFIELDS)
10+
"""
11+
struct RangeAndLinked
12+
# indices that the variable corresponds to in the vectorised parameter
13+
range::UnitRange{Int}
14+
# whether it's linked
15+
is_linked::Bool
16+
end
17+
18+
"""
19+
VectorWithRanges(
20+
iden_varname_ranges::NamedTuple,
21+
varname_ranges::Dict{VarName,RangeAndLinked},
22+
vect::AbstractVector{<:Real},
23+
)
24+
25+
A struct that wraps a vector of parameter values, plus information about how random
26+
variables map to ranges in that vector.
27+
28+
In the simplest case, this could be accomplished only with a single dictionary mapping
29+
VarNames to ranges and link status. However, for performance reasons, we separate out
30+
VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
31+
non-identity-optic VarNames are stored in the `varname_ranges` Dict.
32+
33+
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
34+
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
35+
"""
36+
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
37+
# This NamedTuple stores the ranges for identity VarNames
38+
iden_varname_ranges::N
39+
# This Dict stores the ranges for all other VarNames
40+
varname_ranges::Dict{VarName,RangeAndLinked}
41+
# The full parameter vector which we index into to get variable values
42+
vect::T
43+
end
44+
45+
function _get_range_and_linked(
46+
vr::VectorWithRanges, ::VarName{sym,typeof(identity)}
47+
) where {sym}
48+
return vr.iden_varname_ranges[sym]
49+
end
50+
function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
51+
return vr.varname_ranges[vn]
52+
end
53+
function init(
54+
::Random.AbstractRNG,
55+
vn::VarName,
56+
dist::Distribution,
57+
p::InitFromParams{<:VectorWithRanges},
58+
)
59+
vr = p.params
60+
range_and_linked = _get_range_and_linked(vr, vn)
61+
transform = if range_and_linked.is_linked
62+
from_linked_vec_transform(dist)
63+
else
64+
from_vec_transform(dist)
65+
end
66+
return (@view vr.vect[range_and_linked.range]), transform
67+
end
68+
169
"""
270
AbstractInitStrategy
371
@@ -194,74 +262,6 @@ function init(
194262
end
195263
end
196264

197-
"""
198-
RangeAndLinked
199-
200-
Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable
201-
in the model will in general correspond to a sub-vector of `params`. This struct stores
202-
information about that range, as well as whether the sub-vector represents a linked value or
203-
an unlinked value.
204-
205-
$(TYPEDFIELDS)
206-
"""
207-
struct RangeAndLinked
208-
# indices that the variable corresponds to in the vectorised parameter
209-
range::UnitRange{Int}
210-
# whether it's linked
211-
is_linked::Bool
212-
end
213-
214-
"""
215-
VectorWithRanges(
216-
iden_varname_ranges::NamedTuple,
217-
varname_ranges::Dict{VarName,RangeAndLinked},
218-
vect::AbstractVector{<:Real},
219-
)
220-
221-
A struct that wraps a vector of parameter values, plus information about how random
222-
variables map to ranges in that vector.
223-
224-
In the simplest case, this could be accomplished only with a single dictionary mapping
225-
VarNames to ranges and link status. However, for performance reasons, we separate out
226-
VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
227-
non-identity-optic VarNames are stored in the `varname_ranges` Dict.
228-
229-
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
230-
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
231-
"""
232-
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
233-
# This NamedTuple stores the ranges for identity VarNames
234-
iden_varname_ranges::N
235-
# This Dict stores the ranges for all other VarNames
236-
varname_ranges::Dict{VarName,RangeAndLinked}
237-
# The full parameter vector which we index into to get variable values
238-
vect::T
239-
end
240-
241-
function _get_range_and_linked(
242-
vr::VectorWithRanges, ::VarName{sym,typeof(identity)}
243-
) where {sym}
244-
return vr.iden_varname_ranges[sym]
245-
end
246-
function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
247-
return vr.varname_ranges[vn]
248-
end
249-
function init(
250-
::Random.AbstractRNG,
251-
vn::VarName,
252-
dist::Distribution,
253-
p::InitFromParams{<:VectorWithRanges},
254-
)
255-
vr = p.params
256-
range_and_linked = _get_range_and_linked(vr, vn)
257-
transform = if range_and_linked.is_linked
258-
from_linked_vec_transform(dist)
259-
else
260-
from_vec_transform(dist)
261-
end
262-
return (@view vr.vect[range_and_linked.range]), transform
263-
end
264-
265265
"""
266266
InitContext(
267267
[rng::Random.AbstractRNG=Random.default_rng()],

0 commit comments

Comments
 (0)