Skip to content

Commit 00bf10c

Browse files
authored
Fix for #255 + some other DistributionsAD-stuff (#259)
* fix for #255 and introduction of columnwise * added some tests * version bump * forgot to add the inverse
1 parent 529f6f6 commit 00bf10c

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.12.4"
3+
version = "0.12.5"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/compat/distributionsad.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ bijector(::TuringScalMvNormal) = identity
1313
bijector(::TuringDiagMvNormal) = identity
1414
bijector(::TuringDenseMvNormal) = identity
1515

16-
bijector(d::FillVectorOfUnivariate{Continuous}) = bijector(d.v.value)
17-
bijector(d::FillMatrixOfUnivariate{Continuous}) = up1(bijector(d.dists.value))
16+
bijector(d::FillVectorOfUnivariate{Continuous}) = elementwise(bijector(d.v.value))
17+
bijector(d::FillMatrixOfUnivariate{Continuous}) = elementwise(bijector(d.dists.value))
1818
bijector(d::MatrixOfUnivariate{Discrete}) = identity
1919
bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists)...)
2020
bijector(d::VectorOfMultivariate{Discrete}) = identity
@@ -30,7 +30,7 @@ for T in (:VectorOfMultivariate, :FillVectorOfMultivariate)
3030
bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector()
3131
end
3232
end
33-
bijector(d::FillVectorOfMultivariate{Continuous}) = bijector(d.dists.value)
33+
bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value))
3434

3535
isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true
3636
isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true

src/interface.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,21 @@ elementwise(f) = Base.Fix1(broadcast, f)
1717
# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really
1818
# the way to go.
1919
elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner))
20+
const Columnwise{F} = Base.Fix1{typeof(eachcolmaphcat),F}
21+
"""
22+
23+
Alias for `Base.Fix1(eachcolmaphcat, f)`.
24+
25+
Represents a function `f` which is applied to each column of an input.
26+
"""
27+
columnwise(f) = Base.Fix1(eachcolmaphcat, f)
28+
inverse(f::Columnwise) = columnwise(inverse(f.x))
29+
30+
transform(f::Columnwise, x::AbstractMatrix) = f(x)
31+
function logabsdetjac(f::Columnwise, x::AbstractMatrix)
32+
return sum(Base.Fix1(logabsdetjac, f.x), eachcol(x))
33+
end
34+
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))
2035

2136
######################
2237
# Bijector interface #

test/interface.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,22 @@ end
196196
end
197197
end
198198

199+
@testset "DistributionsAD" begin
200+
@testset "$dist" for dist in [
201+
filldist(Normal(), 2),
202+
filldist(Normal(), 2, 3),
203+
filldist(Exponential(), 2),
204+
filldist(Exponential(), 2, 3),
205+
filldist(filldist(Exponential(), 2), 3),
206+
]
207+
x = rand(dist)
208+
b = bijector(dist)
209+
y = b(x)
210+
td = transformed(dist)
211+
@test logpdf(dist, x) - logabsdetjac(b, x) logpdf(td, y)
212+
end
213+
end
214+
199215
@testset "Stacked <: Bijector" begin
200216
# `logabsdetjac` withOUT AD
201217
d = Beta()

0 commit comments

Comments
 (0)