Skip to content

Commit 2df5699

Browse files
torfjeldedevmotion
andauthored
Rename stack to stack_transforms (#256)
* renamed stack to stack_transforms * bump version * reverted unintentional change * dont export stack * Apply suggestions from code review * added missing constructor * restricted the 2nd argument of Stacked * Update src/bijectors/stacked.jl Co-authored-by: David Widmann <devmotion@users.noreply.github.com> --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 8886728 commit 2df5699

File tree

5 files changed

+7
-11
lines changed

5 files changed

+7
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.12.3"
3+
version = "0.12.4"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/Bijectors.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ export TransformDistribution,
6666
Bijector,
6767
Inverse,
6868
Stacked,
69-
stack,
7069
bijector,
7170
transformed,
7271
UnivariateTransformed,

src/bijectors/stacked.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ b = stack(b1, b2)
2121
b([0.0, 1.0]) == [b1(0.0), 1.0] # => true
2222
```
2323
"""
24-
struct Stacked{Bs, Rs} <: Transform
24+
struct Stacked{Bs, Rs<:Union{Tuple,AbstractArray}} <: Transform
2525
bs::Bs
2626
ranges::Rs
2727
end
2828
Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs)))
2929
Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)])
30+
Stacked(bs...) = Stacked(bs, ntuple(i -> i:i, length(bs)))
3031

3132
# Avoid mixing tuples and arrays.
3233
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges)
@@ -47,7 +48,6 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs)
4748

4849
isinvertible(b::Stacked) = all(isinvertible, b.bs)
4950

50-
stack(bs...) = Stacked(bs)
5151

5252
# For some reason `inverse.(sb.bs)` was unstable... This works though.
5353
inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges)

test/interface.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# using Pkg; Pkg.activate("..")
2-
# using TestEnv; TestEnv.activate()
3-
41
using Test
52
using Random
63
using LinearAlgebra
@@ -204,7 +201,7 @@ end
204201
x = rand(d)
205202
y = b(x)
206203

207-
sb1 = @inferred stack(b, b, inverse(b), inverse(b)) # <= Tuple
204+
sb1 = @inferred Stacked(b, b, inverse(b), inverse(b)) # <= Tuple
208205
res1 = with_logabsdet_jacobian(sb1, [x, x, y, y])
209206
@test sb1(param([x, x, y, y])) isa TrackedArray
210207

@@ -222,7 +219,7 @@ end
222219

223220
# value-test
224221
x = ones(3)
225-
sb = @inferred stack(elementwise(exp), elementwise(log), Shift(5.0))
222+
sb = @inferred Stacked(elementwise(exp), elementwise(log), Shift(5.0))
226223
res = with_logabsdet_jacobian(sb, x)
227224
@test sb(param(x)) isa TrackedArray
228225
@test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0]

test/norm_flows.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,10 @@ end
117117
x = rand(d) .+ 10
118118
y = b(x)
119119

120-
sb = stack(b1, b1)
120+
sb = Stacked(b1, b1)
121121
@test all((sb b)(x) .≤ 1.0)
122122

123-
sb = stack(b1, b2)
123+
sb = Stacked(b1, b2)
124124
cb = (sb b)
125125
y = cb(x)
126126
@test (0 y[1] 1.0) && (0 < y[2])

0 commit comments

Comments
 (0)