@@ -2,23 +2,112 @@ module AbstractPPLDistributionsExt
22
33using AbstractPPL: AbstractPPL, VarName, Accessors
44using Distributions: Distributions
5+ using LinearAlgebra: Cholesky, LowerTriangular, UpperTriangular
6+
7+ #=
8+ This section is copied from Accessors.jl's documentation:
9+ https://juliaobjects.github.io/Accessors.jl/stable/examples/custom_macros/
10+
11+ It defines a wrapper that, when called with `set`, mutates the original value
12+ rather than returning a new value. We need this because the non-mutating optics
13+ don't work for triangular matrices (and hence LKJCholesky): see
14+ https://github.com/JuliaObjects/Accessors.jl/issues/203
15+ =#
16+ struct Lens!{L}
17+ pure:: L
18+ end
19+ (l:: Lens! )(o) = l. pure (o)
20+ function Accessors. set (o, l:: Lens!{<:ComposedFunction} , val)
21+ o_inner = l. pure. inner (o)
22+ return Accessors. set (o_inner, Lens! (l. pure. outer), val)
23+ end
24+ function Accessors. set (o, l:: Lens!{Accessors.PropertyLens{prop}} , val) where {prop}
25+ setproperty! (o, prop, val)
26+ return o
27+ end
28+ function Accessors. set (o, l:: Lens!{<:Accessors.IndexLens} , val)
29+ o[l. pure. indices... ] = val
30+ return o
31+ end
32+
33+ """
34+ get_optics(dist::MultivariateDistribution)
35+ get_optics(dist::MatrixDistribution)
36+ get_optics(dist::LKJCholesky)
37+
38+ Return a complete set of optics for each element of the type returned by `rand(dist)`.
39+ """
40+ function get_optics (
41+ dist:: Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
42+ )
43+ indices = CartesianIndices (size (dist))
44+ return map (idx -> Accessors. IndexLens (idx. I), indices)
45+ end
46+ function get_optics (dist:: Distributions.LKJCholesky )
47+ is_up = dist. uplo == ' U'
48+ cartesian_indices = filter (CartesianIndices (size (dist))) do cartesian_index
49+ i, j = cartesian_index. I
50+ is_up ? i <= j : i >= j
51+ end
52+ # there is an additional layer as we need to access `.L` or `.U` before we
53+ # can index into it
54+ field_lens = is_up ? (Accessors. @o _. U) : (Accessors. @o _. L)
55+ return map (idx -> Accessors. IndexLens (idx. I) ∘ field_lens, cartesian_indices)
56+ end
57+
58+ """
59+ make_empty_value(dist::MultivariateDistribution)
60+ make_empty_value(dist::MatrixDistribution)
61+ make_empty_value(dist::LKJCholesky)
62+
63+ Construct a fresh value filled with zeros that corresponds to the size of `dist`.
64+
65+ For all distributions that this function accepts, it should hold that
66+ `o(make_empty_value(dist))` is zero for all `o` in `get_optics(dist)`.
67+ """
68+ function make_empty_value (
69+ dist:: Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution}
70+ )
71+ return zeros (size (dist))
72+ end
73+ function make_empty_value (dist:: Distributions.LKJCholesky )
74+ if dist. uplo == ' U'
75+ return Cholesky (UpperTriangular (zeros (size (dist))))
76+ else
77+ return Cholesky (LowerTriangular (zeros (size (dist))))
78+ end
79+ end
580
681# TODO (penelopeysm): Figure out tuple / namedtuple distributions, and LKJCholesky (grr)
782function AbstractPPL. hasvalue (
8- vals:: AbstractDict , vn:: VarName , dist:: Distributions.Distribution
83+ vals:: AbstractDict ,
84+ vn:: VarName ,
85+ dist:: Distributions.Distribution ;
86+ error_on_incomplete:: Bool = false ,
987)
1088 @warn " `hasvalue(vals, vn, dist)` is not implemented for $(typeof (dist)) ; falling back to `hasvalue(vals, vn)`."
1189 return AbstractPPL. hasvalue (vals, vn)
1290end
1391function AbstractPPL. hasvalue (
14- vals:: AbstractDict , vn:: VarName , :: Distributions.UnivariateDistribution
92+ vals:: AbstractDict ,
93+ vn:: VarName ,
94+ :: Distributions.UnivariateDistribution ;
95+ error_on_incomplete:: Bool = false ,
1596)
97+ # TODO (penelopeysm): We could also implement a check for the type to catch
98+ # invalid values. Unsure if that is worth it. It may be easier to just let
99+ # the user handle it.
16100 return AbstractPPL. hasvalue (vals, vn)
17101end
18102function AbstractPPL. hasvalue (
19103 vals:: AbstractDict{<:VarName} ,
20104 vn:: VarName{sym} ,
21- dist:: Union{Distributions.MultivariateDistribution,Distributions.MatrixDistribution} ,
105+ dist:: Union {
106+ Distributions. MultivariateDistribution,
107+ Distributions. MatrixDistribution,
108+ Distributions. LKJCholesky,
109+ };
110+ error_on_incomplete:: Bool = false ,
22111) where {sym}
23112 # If `vn` is present as-is, then we are good
24113 AbstractPPL. hasvalue (vals, vn) && return true
@@ -30,13 +119,66 @@ function AbstractPPL.hasvalue(
30119 # To do this, we get the size of the distribution and iterate over all
31120 # possible indices. If every index can be found in `subsumed_keys`, then we
32121 # can return true.
33- sz = size (dist)
34- for idx in Iterators. product (map (Base. OneTo, sz)... )
35- new_optic = Accessors. IndexLens (idx) ∘ AbstractPPL. getoptic (vn)
36- new_vn = VarName {sym} (new_optic)
37- AbstractPPL. hasvalue (vals, new_vn) || return false
122+ optics = get_optics (dist)
123+ original_optic = AbstractPPL. getoptic (vn)
124+ expected_vns = map (o -> VarName {sym} (o ∘ original_optic), optics)
125+ if all (sub_vn -> AbstractPPL. hasvalue (vals, sub_vn), expected_vns)
126+ return true
127+ else
128+ if error_on_incomplete &&
129+ any (sub_vn -> AbstractPPL. hasvalue (vals, sub_vn), expected_vns)
130+ error (" hasvalue: only partial values for `$vn ` found in the values provided" )
131+ end
132+ return false
133+ end
134+ end
135+
136+ function AbstractPPL. getvalue (
137+ vals:: AbstractDict , vn:: VarName , dist:: Distributions.Distribution ;
138+ )
139+ @warn " `getvalue(vals, vn, dist)` is not implemented for $(typeof (dist)) ; falling back to `getvalue(vals, vn)`."
140+ return AbstractPPL. getvalue (vals, vn)
141+ end
142+ function AbstractPPL. getvalue (
143+ vals:: AbstractDict , vn:: VarName , :: Distributions.UnivariateDistribution ;
144+ )
145+ # TODO (penelopeysm): We could also implement a check for the type to catch
146+ # invalid values. Unsure if that is worth it. It may be easier to just let
147+ # the user handle it.
148+ return AbstractPPL. getvalue (vals, vn)
149+ end
150+ function AbstractPPL. getvalue (
151+ vals:: AbstractDict{<:VarName} ,
152+ vn:: VarName{sym} ,
153+ dist:: Union {
154+ Distributions. MultivariateDistribution,
155+ Distributions. MatrixDistribution,
156+ Distributions. LKJCholesky,
157+ };
158+ ) where {sym}
159+ # If `vn` is present as-is, then we can just return that
160+ AbstractPPL. hasvalue (vals, vn) && return AbstractPPL. getvalue (vals, vn)
161+ # If not, then we need to start looking inside `vals`, in exactly the
162+ # same way we did for `hasvalue`.
163+ optics = get_optics (dist)
164+ original_optic = AbstractPPL. getoptic (vn)
165+ expected_vns = map (o -> VarName {sym} (o ∘ original_optic), optics)
166+ if all (sub_vn -> AbstractPPL. hasvalue (vals, sub_vn), expected_vns)
167+ # Reconstruct the value index by index.
168+ value = make_empty_value (dist)
169+ for (o, sub_vn) in zip (optics, expected_vns)
170+ # Retrieve the value of this given index
171+ sub_value = AbstractPPL. getvalue (vals, sub_vn)
172+ # Set it inside the value we're reconstructing.
173+ # Note: `o` is normally non-mutating. We have to wrap it in `Lens!`
174+ # to make it mutating, because Cholesky distributions are broken
175+ # by https://github.com/JuliaObjects/Accessors.jl/issues/203.
176+ Accessors. set (value, Lens! (o), sub_value)
177+ end
178+ return value
179+ else
180+ error (" getvalue: $(vn) was not found in the values provided" )
38181 end
39- return true
40182end
41183
42184end
0 commit comments