Skip to content

Commit 529f6f6

Browse files
torfjeldedevmotionyebai
authored
Added Reshape (#257)
* initial work * added tests for reshape * bump patch version * added tests for some reshaped distribtuions * minor change * added docstring to Reshape * bump distributions version * implement with_logabsdet_jacobian instead * Update src/bijectors/reshape.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * fixed typo --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com> Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
1 parent 2df5699 commit 529f6f6

File tree

7 files changed

+50
-3
lines changed

7 files changed

+50
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ArgCheck = "1, 2"
2626
ChainRulesCore = "0.10.11, 1"
2727
ChangesOfVariables = "0.1"
2828
Compat = "3, 4"
29-
Distributions = "0.23.3, 0.24, 0.25"
29+
Distributions = "0.25.33"
3030
Functors = "0.1, 0.2, 0.3, 0.4"
3131
InverseFunctions = "0.1"
3232
IrrationalConstants = "0.1, 0.2"

src/bijectors/reshape.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
Reshape(in_shape, out_shape)
3+
4+
A [`Bijector`](@ref) that reshapes the input to the output shape.
5+
6+
# Example
7+
8+
```jldoctest
9+
julia> using Bijectors: Reshape
10+
11+
julia> b = Reshape((2, 3), (3, 2))
12+
Reshape{Tuple{Int64, Int64}, Tuple{Int64, Int64}}((2, 3), (3, 2))
13+
14+
julia> Array(transform(b, reshape(1:6, 2, 3)))
15+
3×2 Matrix{Int64}:
16+
1 4
17+
2 5
18+
3 6
19+
"""
20+
struct Reshape{N1,N2} <: Bijector
21+
in_shape::NTuple{N1,Int}
22+
out_shape::NTuple{N2,Int}
23+
end
24+
25+
inverse(b::Reshape) = Reshape(b.out_shape, b.in_shape)
26+
27+
with_logabsdet_jacobian(b::Reshape, x) = reshape(x, b.out_shape), zero(eltype(x))

src/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ logabsdetjac!(::typeof(identity), x, logjac) = logjac
200200
# General
201201
include("bijectors/composed.jl")
202202
include("bijectors/stacked.jl")
203+
include("bijectors/reshape.jl")
203204

204205
# Specific
205206
include("bijectors/exp_log.jl")

src/transformed_distribution.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ bijector(d::MatrixBeta) = PDBijector()
7979

8080
bijector(d::LKJ) = CorrBijector()
8181

82+
function bijector(d::Distributions.ReshapedDistribution)
83+
inner_dims = size(d.dist)
84+
outer_dims = d.dims
85+
b = Reshape(outer_dims, inner_dims)
86+
return inverse(b) bijector(d.dist) b
87+
end
88+
8289
##############################
8390
# Distributions.jl interface #
8491
##############################

test/bijectors/reshape.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using Bijectors: Reshape
2+
3+
@testset "Reshape" begin
4+
dist = reshape(product_distribution(fill(InverseGamma(2, 3), 10)), 2, 5)
5+
b = bijector(dist)
6+
7+
x = rand(dist)
8+
test_bijector(b, x)
9+
end

test/interface.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ end
120120
Dirichlet([1000 * one(Float64), eps(Float64)]),
121121
Dirichlet([eps(Float64), 1000 * one(Float64)]),
122122
transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
123-
transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))))
123+
transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))),
124+
transformed(reshape(product_distribution(fill(InverseGamma(2, 3), 6)), 2, 3)),
124125
]
125126

126127
for dist in vector_dists
@@ -172,7 +173,8 @@ end
172173
InverseWishart(v,S),
173174
TuringWishart(v,S),
174175
TuringInverseWishart(v,S),
175-
LKJ(3, 1.)
176+
LKJ(3, 1.),
177+
reshape(MvNormal(zeros(6), I), 2, 3),
176178
]
177179

178180
for dist in matrix_dists

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ if GROUP == "All" || GROUP == "Interface"
3636
include("bijectors/coupling.jl")
3737
include("bijectors/ordered.jl")
3838
include("bijectors/pd.jl")
39+
include("bijectors/reshape.jl")
3940
end
4041

4142
if GROUP == "All" || GROUP == "AD"

0 commit comments

Comments
 (0)