|
53 | 53 | @test (model′.x, model′.y, model′.z) == (1, 4, 3) |
54 | 54 | end |
55 | 55 |
|
| 56 | +@testset "cache" begin |
| 57 | + shared = [1,2,3] |
| 58 | + m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3]))) |
| 59 | + m1f = fmap(float, m1) |
| 60 | + @test m1f.x === m1f.y.y.x |
| 61 | + @test m1f.x !== m1f.y.x |
| 62 | + m1p = fmapstructure(identity, m1; prune = nothing) |
| 63 | + @test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3]))) |
| 64 | + |
| 65 | + # A non-leaf node can also be repeated: |
| 66 | + m2 = Foo(Foo(shared, 4), Foo(shared, 4)) |
| 67 | + @test m2.x === m2.y |
| 68 | + m2f = fmap(float, m2) |
| 69 | + @test m2f.x.x === m2f.y.x |
| 70 | + m2p = fmapstructure(identity, m2; prune = Bar(0)) |
| 71 | + @test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0)) |
| 72 | +end |
| 73 | + |
56 | 74 | ### |
57 | 75 | ### Extras |
58 | 76 | ### |
|
91 | 109 | ### |
92 | 110 |
|
93 | 111 | @testset "fmap(f, x, y)" begin |
94 | | - @test true # TODO |
| 112 | + m1 = (x = [1,2], y = 3) |
| 113 | + n1 = (x = [4,5], y = 6) |
| 114 | + @test fmap(+, m1, n1) == (x = [5, 7], y = 9) |
| 115 | + |
| 116 | + # Reconstruction type comes from the first argument |
| 117 | + foo1 = Foo([7,8], 9) |
| 118 | + @test_broken fmap(+, m1, foo1) == (x = [8, 10], y = 12) # https://github.com/FluxML/Functors.jl/issues/38 |
| 119 | + @test fmap(+, foo1, n1) isa Foo |
| 120 | + @test fmap(+, foo1, n1).x == [11, 13] |
| 121 | + |
| 122 | + # Mismatched trees should be an error |
| 123 | + m2 = (x = [1,2], y = (a = [3,4], b = 5)) |
| 124 | + n2 = (x = [6,7], y = 8) |
| 125 | + @test_throws ArgumentError fmap(first∘tuple, m2, n2) |
| 126 | + @test_broken @test_throws ArgumentError fmap(first∘tuple, m2, n2) # now (x = [6, 7], y = 8) |
| 127 | + |
| 128 | + # The cache uses IDs from the first argument |
| 129 | + shared = [1,2,3] |
| 130 | + m3 = (x = shared, y = [4,5,6], z = shared) |
| 131 | + n3 = (x = shared, y = shared, z = [7,8,9]) |
| 132 | + @test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6]) |
| 133 | + z3 = fmap(+, m3, n3) |
| 134 | + @test z3.x === z3.z |
95 | 135 | end |
96 | 136 |
|
97 | 137 | @testset "old test update.jl" begin |
|
0 commit comments