Skip to content

Commit 7ab2efd

Browse files
committed
add some tests
1 parent 4adbcf1 commit 7ab2efd

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/functor.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ end
8181
function _default_walk(f, x, ys...)
8282
func, re = functor(x)
8383
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
84+
for yf in yfuncs
85+
nx = propertynames(func)
86+
ny = propertynames(yf)
87+
nx == ny || throw(ArgumentError("names of children must agree, got $nx != $ny"))
88+
end
8489
re(map(f, func, yfuncs...))
8590
end
8691

test/basics.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ end
5353
@test (model′.x, model′.y, model′.z) == (1, 4, 3)
5454
end
5555

56+
@testset "cache" begin
57+
shared = [1,2,3]
58+
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
59+
m1f = fmap(float, m1)
60+
@test m1f.x === m1f.y.y.x
61+
@test m1f.x !== m1f.y.x
62+
m1p = fmapstructure(identity, m1; prune = nothing)
63+
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))
64+
65+
# A non-leaf node can also be repeated:
66+
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
67+
@test m2.x === m2.y
68+
m2f = fmap(float, m2)
69+
@test m2f.x.x === m2f.y.x
70+
m2p = fmapstructure(identity, m2; prune = Bar(0))
71+
@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))
72+
end
73+
5674
###
5775
### Extras
5876
###
@@ -91,7 +109,29 @@ end
91109
###
92110

93111
@testset "fmap(f, x, y)" begin
94-
@test true # TODO
112+
m1 = (x = [1,2], y = 3)
113+
n1 = (x = [4,5], y = 6)
114+
@test fmap(+, m1, n1) == (x = [5, 7], y = 9)
115+
116+
# Reconstruction type comes from the first argument
117+
foo1 = Foo([7,8], 9)
118+
@test_broken fmap(+, m1, foo1) == (x = [8, 10], y = 12) # https://github.com/FluxML/Functors.jl/issues/38
119+
@test fmap(+, foo1, n1) isa Foo
120+
@test fmap(+, foo1, n1).x == [11, 13]
121+
122+
# Mismatched trees should be an error
123+
m2 = (x = [1,2], y = (a = [3,4], b = 5))
124+
n2 = (x = [6,7], y = 8)
125+
@test_throws ArgumentError fmap(firsttuple, m2, n2)
126+
@test_broken @test_throws ArgumentError fmap(firsttuple, m2, n2) # now (x = [6, 7], y = 8)
127+
128+
# The cache uses IDs from the first argument
129+
shared = [1,2,3]
130+
m3 = (x = shared, y = [4,5,6], z = shared)
131+
n3 = (x = shared, y = shared, z = [7,8,9])
132+
@test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
133+
z3 = fmap(+, m3, n3)
134+
@test z3.x === z3.z
95135
end
96136

97137
@testset "old test update.jl" begin

0 commit comments

Comments
 (0)