Skip to content

Commit 75b5d15

Browse files
Added missing impl of output_size for ComposedFunction (#296)
* add `output_size` for `ComposedFunction` * add tests for `output_size` for `ComposedFunction` * bump patch version * Update test/bijectors/stacked.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 04b79dd commit 75b5d15

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.13.7"
3+
version = "0.13.8"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(
4242
Returns the output size of `f` given the input size `sz`.
4343
"""
4444
output_size(f, sz) = sz
45+
output_size(f::ComposedFunction, sz) = output_size(f.outer, output_size(f.inner, sz))
4546

4647
"""
4748
output_length(f, len::Int)

test/bijectors/stacked.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,44 @@ end
3434
@test y == [exp(1.0), 2.0]
3535
@test binv(y) == [1.0, 2.0, 0.0]
3636
end
37+
38+
@testset "composition" begin
39+
# Composition with one dimension reduction.
40+
b = Stacked((elementwise(exp), ProjectionBijector() identity), [1:1, 2:3])
41+
binv = inverse(b)
42+
x = [1.0, 2.0, 3.0]
43+
y = b(x)
44+
x_ = binv(y)
45+
46+
# Are the values of correct size?
47+
@test size(y) == (2,)
48+
@test size(x_) == (3,)
49+
# Can we determine the sizes correctly?
50+
@test Bijectors.output_size(b, size(x)) == (2,)
51+
@test Bijectors.output_size(binv, size(y)) == (3,)
52+
53+
# Are values correct?
54+
@test y == [exp(1.0), 2.0]
55+
@test binv(y) == [1.0, 2.0, 0.0]
56+
57+
# Composition with two dimension reductions.
58+
b = Stacked(
59+
(elementwise(exp), ProjectionBijector() ProjectionBijector()), [1:1, 2:4]
60+
)
61+
binv = inverse(b)
62+
x = [1.0, 2.0, 3.0, 4.0]
63+
y = b(x)
64+
x_ = binv(y)
65+
66+
# Are the values of correct size?
67+
@test size(y) == (2,)
68+
@test size(x_) == (4,)
69+
# Can we determine the sizes correctly?
70+
@test Bijectors.output_size(b, size(x)) == (2,)
71+
@test Bijectors.output_size(binv, size(y)) == (4,)
72+
73+
# Are values correct?
74+
@test y == [exp(1.0), 2.0]
75+
@test binv(y) == [1.0, 2.0, 0.0, 0.0]
76+
end
3777
end

0 commit comments

Comments
 (0)