|
| 1 | +# Test struct for `rand_tangent` and `difference`. |
| 2 | +struct Bar |
| 3 | + a::Float64 |
| 4 | + b::Int |
| 5 | + c::Any |
| 6 | + end |
| 7 | +@testset "rand_tangent" begin |
| 8 | + rng = MersenneTwister(123456) |
| 9 | + |
| 10 | + @testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [ |
| 11 | + |
| 12 | + # Things without sensible tangents. |
| 13 | + ("hi", NoTangent), |
| 14 | + ('a', NoTangent), |
| 15 | + (:a, NoTangent), |
| 16 | + (true, NoTangent), |
| 17 | + (4, NoTangent), |
| 18 | + (FiniteDifferences, NoTangent), # Module object |
| 19 | + # Types (not instances of type) |
| 20 | + (Bar, NoTangent), |
| 21 | + (Union{Int, Bar}, NoTangent), |
| 22 | + (Union{Int, Bar}, NoTangent), |
| 23 | + (Vector, NoTangent), |
| 24 | + (Vector{Float64}, NoTangent), |
| 25 | + (Integer, NoTangent), |
| 26 | + (Type{<:Real}, NoTangent), |
| 27 | + |
| 28 | + # Numbers. |
| 29 | + (5.0, Float64), |
| 30 | + (5.0 + 0.4im, Complex{Float64}), |
| 31 | + (big(5.0), BigFloat), |
| 32 | + |
| 33 | + # StridedArrays. |
| 34 | + (fill(randn(Float32)), Array{Float32, 0}), |
| 35 | + (fill(randn(Float64)), Array{Float64, 0}), |
| 36 | + (randn(Float32, 3), Vector{Float32}), |
| 37 | + (randn(Complex{Float64}, 2), Vector{Complex{Float64}}), |
| 38 | + (randn(5, 4), Matrix{Float64}), |
| 39 | + (randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}), |
| 40 | + ([randn(5, 4), 4.0], Vector{Any}), |
| 41 | + |
| 42 | + # Wrapper Arrays |
| 43 | + (randn(5, 4)', Adjoint{Float64, Matrix{Float64}}), |
| 44 | + (transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}), |
| 45 | + |
| 46 | + |
| 47 | + # Tuples. |
| 48 | + ((4.0, ), Tangent{Tuple{Float64}}), |
| 49 | + ((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}), |
| 50 | + |
| 51 | + # NamedTuples. |
| 52 | + ((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}), |
| 53 | + ((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}), |
| 54 | + |
| 55 | + # structs. |
| 56 | + (Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}), |
| 57 | + (Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}), |
| 58 | + (sin, NoTangent), |
| 59 | + # all fields NoTangent implies NoTangent |
| 60 | + (Pair(:a, "b"), NoTangent), |
| 61 | + (1:10, NoTangent), |
| 62 | + (1:2:10, NoTangent), |
| 63 | + |
| 64 | + # LinearAlgebra types (also just structs). |
| 65 | + ( |
| 66 | + UpperTriangular(randn(3, 3)), |
| 67 | + Tangent{UpperTriangular{Float64, Matrix{Float64}}}, |
| 68 | + ), |
| 69 | + ( |
| 70 | + Diagonal(randn(2)), |
| 71 | + Tangent{Diagonal{Float64, Vector{Float64}}}, |
| 72 | + ), |
| 73 | + ( |
| 74 | + Symmetric(randn(2, 2)), |
| 75 | + Tangent{Symmetric{Float64, Matrix{Float64}}}, |
| 76 | + ), |
| 77 | + ( |
| 78 | + Hermitian(randn(ComplexF64, 1, 1)), |
| 79 | + Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}}, |
| 80 | + ), |
| 81 | + ] |
| 82 | + @test rand_tangent(rng, x) isa T_tangent |
| 83 | + @test rand_tangent(x) isa T_tangent |
| 84 | + end |
| 85 | + |
| 86 | + @testset "erroring cases" begin |
| 87 | + # Ensure struct fallback errors for non-struct types. |
| 88 | + @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0) |
| 89 | + end |
| 90 | + |
| 91 | + @testset "compsition of addition" begin |
| 92 | + x = Bar(1.5, 2, Bar(1.1, 3, [1.7, 1.4, 0.9])) |
| 93 | + @test x + rand_tangent(x) isa typeof(x) |
| 94 | + @test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x) |
| 95 | + end |
| 96 | + |
| 97 | + # Julia 1.6 changed to using Ryu printing algorithm and seems better at printing short |
| 98 | + VERSION > v"1.6" && @testset "niceness of printing" begin |
| 99 | + for i in 1:50 |
| 100 | + @test length(string(rand_tangent(1.0))) <= 6 |
| 101 | + @test length(string(rand_tangent(1.0 + 1.0im))) <= 12 |
| 102 | + @test length(string(rand_tangent(1f0))) <= 12 |
| 103 | + @test length(string(rand_tangent(big"1.0"))) <= 12 |
| 104 | + end |
| 105 | + end |
| 106 | +end |
0 commit comments