Skip to content

Commit c49a3c6

Browse files
committed
fix transforms for pathological distributions
1 parent 02451a3 commit c49a3c6

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

src/contexts/init.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Obtain new values by sampling from the prior distribution.
2929
"""
3030
struct InitFromPrior <: AbstractInitStrategy end
3131
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior)
32-
return rand(rng, dist), identity
32+
return rand(rng, dist), _typed_identity
3333
end
3434

3535
"""
@@ -69,7 +69,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
6969
if x isa Array{<:Any,0}
7070
x = x[]
7171
end
72-
return x, identity
72+
return x, _typed_identity
7373
end
7474

7575
"""
@@ -120,7 +120,7 @@ function init(
120120
else
121121
# TODO(penelopeysm): Since x is user-supplied, maybe we could also
122122
# check here that the type / size of x matches the dist?
123-
x, identity
123+
x, _typed_identity
124124
end
125125
else
126126
p.fallback === nothing && error("No value was provided for the variable `$(vn)`.")
@@ -238,19 +238,25 @@ function tilde_assume!!(
238238
y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x)
239239
# Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian
240240
# calculation wastes a lot of time going from linked vectorised -> unlinked ->
241-
# linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. However,
242-
# `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which case this
243-
# branch is never hit (since `in_varinfo` will always be false). So we can leave
244-
# this branch in for full generality with other combinations of init strategies /
245-
# VarInfo.
241+
# linked, and `inv_logjac` will also just be the negative of `fwd_logjac`.
242+
#
243+
# However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which
244+
# case this branch is never hit (since `in_varinfo` will always be false). It does
245+
# mean that the combination of InitFromParams{<:VectorWithRanges} with a full,
246+
# linked, VarInfo will be very slow. That should never really be used, though. So
247+
# (at least for now) we can leave this branch in for full generality with other
248+
# combinations of init strategies / VarInfo.
246249
#
247250
# TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue
248251
# is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`,
249252
# which is NOT the same as `inverse(link_transform)` (because there is an additional
250253
# vectorisation step). We need `init` and `tilde_assume!!` to share this information
251254
# but it's not clear right now how to do this. In my opinion, the most productive
252-
# way forward would be to standardise the behaviour of bijectors so that we can have
253-
# a clean separation between the linking and vectorisation parts of it.
255+
# way forward would be to clean up the behaviour of bijectors so that we can have a
256+
# clean separation between the linking and vectorisation parts of it. That way, `x`
257+
# can either be unlinked, unlinked vectorised, linked, or linked vectorised, and
258+
# regardless of which it is, we should only need to apply at most one linking and
259+
# one vectorisation transform.
254260
y, -inv_logjac + fwd_logjac
255261
else
256262
x, -inv_logjac

src/utils.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems.
1515
"""
1616
const LogProbType = float(Real)
1717

18+
"""
19+
_typed_identity(x)
20+
21+
Identity function, but with an overload for `with_logabsdet_jacobian` to ensure
22+
that it returns a sensible zero logjac.
23+
24+
The problem with plain old `identity` is that the default definition of
25+
`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`:
26+
https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154
27+
28+
This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g.
29+
if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with,
30+
for example, `ProductNamedTupleDistribution`:
31+
32+
```julia
33+
julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5)));
34+
35+
julia> eltype(rand(d))
36+
Any
37+
```
38+
39+
The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to
40+
support distributions with non-numeric samples.
41+
42+
Furthermore, in principle, the type of the log-probability should be separate from the type
43+
of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the
44+
LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do
45+
this is discovered, then `_typed_identity` is what will allow us to obtain that custom
46+
behaviour.
47+
"""
48+
function _typed_identity end
49+
@inline _typed_identity(x) = x
50+
@inline Bijectors.with_logabsdet_jacobian(::typeof(_typed_identity), x) =
51+
(x, zero(LogProbType))
52+
1853
"""
1954
@addlogprob!(ex)
2055

0 commit comments

Comments
 (0)