|
1 | | -@testset "Base" begin |
2 | | - @testset "RefValue" begin |
3 | | - x = Ref(1) |
4 | | - p, re = Functors.functor(x) |
5 | | - @test p == (x = 1,) |
6 | | - @test re(p) isa Base.RefValue{Int} |
7 | | - end |
| 1 | + |
| 2 | +@testset "RefValue" begin |
| 3 | + @test fmap(sqrt, Ref(16))[] == 4.0 |
| 4 | + @test fmap(sqrt, Ref(16)) isa Ref |
| 5 | + @test fmapstructure(sqrt, Ref(16)) === (x = 4.0,) |
| 6 | + |
| 7 | + x = Ref(13) |
| 8 | + p, re = Functors.functor(x) |
| 9 | + @test p == (x = 13,) |
| 10 | + @test re(p) isa Base.RefValue{Int} |
| 11 | + |
| 12 | + x2 = (a = x, b = [7, x, nothing], c = (7, nothing, Ref(13))) |
| 13 | + y2 = fmap(identity, x2) |
| 14 | + @test x2.a !== y2.a # it's a new Ref |
| 15 | + @test y2.a === y2.b[2] # relation is maintained |
| 16 | + @test y2.a !== y2.c[3] # no new relation created |
| 17 | + |
| 18 | + x3 = Ref([3.14]) |
| 19 | + f3 = [Foo(x3, x), x3, x] |
| 20 | + @test f3[1].x === f3[2] |
| 21 | + y3 = fmapstructure(identity, f3) # replaces mutable with immutable |
| 22 | + @test y3[1].x === y3[2] |
| 23 | + @test y3[1].x.x === y3[2].x |
| 24 | + z3 = fmapstructure(identity, y3) |
| 25 | + @test z3[1].x === z3[2] |
| 26 | + @test z3[1].x.x === z3[2].x |
| 27 | +end |
| 28 | + |
| 29 | +@testset "ComposedFunction" begin |
| 30 | + f1 = Foo(1.1, 2.2) |
| 31 | + f2 = Bar(3.3) |
| 32 | + @test Functors.functor(f1 ∘ f2)[1] == (outer = f1, inner = f2) |
| 33 | + @test Functors.functor(f1 ∘ f2)[2]((outer = f1, inner = f2)) == f1 ∘ f2 |
| 34 | + @test fmap(x -> x + 10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) |
| 35 | +end |
| 36 | + |
| 37 | +@testset "LinearAlgebra containers" begin |
| 38 | + @test fmapstructure(identity, [1,2,3]') == (parent = [1, 2, 3],) |
| 39 | + @test fmapstructure(identity, transpose([1,2,3])) == (parent = [1, 2, 3],) |
| 40 | + |
| 41 | + CNT = Ref(0) |
| 42 | + fv(x::Vector) = (CNT[]+=1; 10v) |
| 43 | + |
| 44 | + v = [1,2,3] |
| 45 | + nt = fmap(fv, (a=v, b=v', c=transpose(v), d=[1,2,3]')) |
| 46 | + |
| 47 | + @test nt.a === adjoint(nt.b) # does not break tie |
| 48 | + @test nt.a === transpose(nt.c) |
| 49 | + |
| 50 | + @test CNT[] == 2 |
| 51 | + @test nt.a == adjoint(nt.d) # does not create a new tie |
| 52 | + @test nt.a !== adjoint(nt.d) |
| 53 | + |
| 54 | + @test nt.b isa Adjoint |
| 55 | + @test nt.c isa Transpose |
| 56 | + |
| 57 | + x = [1,2,3]' |
| 58 | + xs = fmapstructure(identity, x) # check it digests this, e.g. structural gradient representation |
| 59 | + @test Functors.functor(typeof(x), xs) == Functors.functor(x) # (no real need for [2] types to match) |
| 60 | + |
| 61 | + x = transpose([1 2; 3 4]) |
| 62 | + yt = transpose([5 6; 7 8]) |
| 63 | + ym = Matrix(yt) # check it digests this, e.g. simplest Matrix gradient |
| 64 | + @test Functors.functor(typeof(x), yt)[1].parent == Functors.functor(typeof(x), ym)[1].parent |
| 65 | + |
| 66 | + ybc = Broadcast.broadcasted(+, ym, 9) # check it digests this, as Optimisers.jl makes these |
| 67 | + collect(ybc) isa Vector |
| 68 | + zbc = Functors.functor(typeof(x), ybc)[1].parent |
| 69 | + @test zbc .+ 0 == Functors.functor(typeof(x), ym .+ 9)[1].parent |
| 70 | + |
| 71 | + # Similar checks for Adjoint. |
| 72 | + x = adjoint([1 2im 3; 4im 5 6im]) |
| 73 | + yt = adjoint([7im 8 9; 0 im 2]) |
| 74 | + ym = Matrix(yt) |
| 75 | + @test Functors.functor(typeof(x), yt)[1].parent == Functors.functor(typeof(x), ym)[1].parent |
| 76 | + |
| 77 | + ybc = Broadcast.broadcasted(+, ym, [11im, 12, im]) |
| 78 | + collect(ybc) isa Vector |
| 79 | + zbc = Functors.functor(typeof(x), ybc)[1].parent |
| 80 | + @test zbc .+ 0 == Functors.functor(typeof(x), ym .+ [11im, 12, im])[1].parent |
| 81 | +end |
| 82 | + |
| 83 | +@testset "PermutedDimsArray" begin |
| 84 | + @test fmapstructure(identity, PermutedDimsArray([1 2; 3 4], (2,1))) == (parent = [1 2; 3 4],) |
| 85 | + @test fmap(exp, PermutedDimsArray([1 2; 3 4], (2,1))) isa PermutedDimsArray{Float64} |
8 | 86 | end |
0 commit comments