Skip to content

Commit 04b79dd

Browse files
Fix for #287 (#288)
* added has_constant_bijector and made bijector of product distributions return the identity whenever possible * no need to limit ourselves to identity for constant bijectors * no need to limit ourselves to identity for Product * bump patch version * Update test/bijectors/ordered.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed tests * attempt at fix for ordered MvTDist test * dispatch on GenericMvTDist instead of TDist * Update test/bijectors/ordered.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added some tests for MvTDist * make elementwise acting on identity return identity * fixed bug in error --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9419d4f commit 04b79dd

File tree

5 files changed

+57
-18
lines changed

5 files changed

+57
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.13.6"
3+
version = "0.13.7"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ In the case where `f::ComposedFunction`, the result is
1414
`Base.Fix1(broadcast, f)`.
1515
"""
1616
elementwise(f) = Base.Fix1(broadcast, f)
17+
elementwise(f::typeof(identity)) = identity
1718
# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really
1819
# the way to go.
1920
function elementwise(f::ComposedFunction)
@@ -91,7 +92,7 @@ function transform(t::Transform, x)
9192
res = with_logabsdet_jacobian(t, x)
9293
if res isa ChangesOfVariables.NoLogAbsDetJacobian
9394
error(
94-
"`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.",
95+
"`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.",
9596
)
9697
end
9798

src/transformed_distribution.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,43 @@ function bijector(td::TransformedDistribution)
4747
b = bijector(td.dist)
4848
return b === identity ? inverse(td.transform) : b inverse(td.transform)
4949
end
50+
51+
"""
52+
has_constant_bijector(dist_type::Type)
53+
54+
Returns `true` if the distribution type `dist_type` has a constant bijector,
55+
i.e. the return-value of [`bijector`](@ref) does not depend on runtime information.
56+
"""
57+
has_constant_bijector(d::Type) = false
58+
has_constant_bijector(d::Type{<:Normal}) = true
59+
has_constant_bijector(d::Type{<:Distributions.AbstractMvNormal}) = true
60+
has_constant_bijector(d::Type{<:Distributions.AbstractMvLogNormal}) = true
61+
has_constant_bijector(d::Type{<:TDist}) = true
62+
has_constant_bijector(d::Type{<:Distributions.GenericMvTDist}) = true
63+
has_constant_bijector(d::Type{<:PositiveDistribution}) = true
64+
has_constant_bijector(d::Type{<:SimplexDistribution}) = true
65+
has_constant_bijector(d::Type{<:KSOneSided}) = true
66+
function has_constant_bijector(::Type{<:Product{Continuous,D}}) where {D}
67+
return has_constant_bijector(D)
68+
end
69+
70+
# Container distributions.
5071
bijector(d::DiscreteUnivariateDistribution) = identity
5172
bijector(d::DiscreteMultivariateDistribution) = identity
5273
bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d))
5374
bijector(d::Product{Discrete}) = identity
5475
function bijector(d::Product{Continuous})
55-
return TruncatedBijector(_minmax(d.v)...)
76+
D = eltype(d.v)
77+
return if has_constant_bijector(D)
78+
elementwise(bijector(d.v[1]))
79+
else
80+
# FIXME: This is not great. Should use something like
81+
# `Stacked(map(bijector, d.v))` instead.
82+
# TODO: Specialize. F.ex. for FillArrays.jl we can do much better.
83+
TruncatedBijector(_minmax(d.v)...)
84+
end
5685
end
86+
5787
@generated function _minmax(d::AbstractArray{T}) where {T}
5888
try
5989
min, max = minimum(T), maximum(T)
@@ -63,9 +93,12 @@ end
6393
end
6494
end
6595

96+
# Specialized implementations.
6697
bijector(d::Normal) = identity
6798
bijector(d::Distributions.AbstractMvNormal) = identity
6899
bijector(d::Distributions.AbstractMvLogNormal) = elementwise(log)
100+
bijector(d::TDist) = identity
101+
bijector(d::Distributions.GenericMvTDist) = identity
69102
bijector(d::PositiveDistribution) = elementwise(log)
70103
bijector(d::SimplexDistribution) = SimplexBijector()
71104
bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d)))

test/bijectors/ordered.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,23 @@ using LinearAlgebra
1717
end
1818

1919
@testset "ordered" begin
20-
d = MvNormal(1:5, Diagonal(6:10))
21-
d_ordered = ordered(d)
22-
@test d_ordered isa Bijectors.TransformedDistribution
23-
@test d_ordered.dist === d
24-
@test d_ordered.transform isa OrderedBijector
25-
y = randn(5)
26-
x = inverse(bijector(d_ordered))(y)
27-
@test issorted(x)
20+
@testset "$d" for d in [
21+
MvNormal(1:5, Diagonal(6:10)),
22+
MvTDist(1, collect(1.0:5), Matrix(I(5))),
23+
product_distribution(fill(Normal(), 5)),
24+
product_distribution(fill(TDist(1), 5)),
25+
]
26+
d_ordered = ordered(d)
27+
@test d_ordered isa Bijectors.TransformedDistribution
28+
@test d_ordered.dist === d
29+
@test d_ordered.transform isa OrderedBijector
30+
y = randn(5)
31+
x = inverse(bijector(d_ordered))(y)
32+
@test issorted(x)
33+
end
2834

29-
d = Product(fill(Normal(), 5))
30-
# currently errors because `bijector(Product(fill(Normal(), 5)))` is not an `Identity`
31-
@test_broken ordered(d) isa Bijectors.TransformedDistribution
32-
33-
# non-Identity bijector is not supported
34-
d = Dirichlet(ones(5))
35-
@test_throws ArgumentError ordered(d)
35+
@testset "non-identity bijector is not supported" begin
36+
d = Dirichlet(ones(5))
37+
@test_throws ArgumentError ordered(d)
38+
end
3639
end

test/interface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ end
136136
MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
137137
Dirichlet([1000 * one(Float64), eps(Float64)]),
138138
Dirichlet([eps(Float64), 1000 * one(Float64)]),
139+
MvTDist(1, randn(10), Matrix(Diagonal(exp.(randn(10))))),
139140
transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
140141
transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))),
141142
transformed(reshape(product_distribution(fill(InverseGamma(2, 3), 6)), 2, 3)),
@@ -200,6 +201,7 @@ end
200201
TuringInverseWishart(v, S),
201202
LKJ(3, 1.0),
202203
reshape(MvNormal(zeros(6), I), 2, 3),
204+
product_distribution(fill(InverseGamma(2, 3), 6)),
203205
]
204206

205207
for dist in matrix_dists

0 commit comments

Comments
 (0)