@@ -20,14 +20,14 @@ Base.length(x::DummyType) = size(x.X, 1)
2020 rng, fdm = MersenneTwister (123456 ), central_fdm (5 , 1 )
2121 x = randn (rng, T, 2 )
2222 xc = copy (x)
23- @test grad (fdm, x-> sin (x[1 ]) + cos (x[2 ]), x) ≈ [cos (x[1 ]), - sin (x[2 ])]
23+ @test grad (fdm, x-> sin (x[1 ]) + cos (x[2 ]), x)[ 1 ] ≈ [cos (x[1 ]), - sin (x[2 ])]
2424 @test xc == x
2525 end
2626
2727 function check_jac_and_jvp_and_j′vp (fdm, f, ȳ, x, ẋ, J_exact)
2828 xc = copy (x)
29- @test jacobian (fdm, f, x; len= length (ȳ)) ≈ J_exact
30- @test jacobian (fdm, f, x) == jacobian (fdm, f, x; len= length (ȳ))
29+ @test jacobian (fdm, f, x; len= length (ȳ))[ 1 ] ≈ J_exact
30+ @test jacobian (fdm, f, x)[ 1 ] == jacobian (fdm, f, x; len= length (ȳ))[ 1 ]
3131 @test _jvp (fdm, f, x, ẋ) ≈ J_exact * ẋ
3232 @test _j′vp (fdm, f, ȳ, x) ≈ transpose (J_exact) * ȳ
3333 @test xc == x
@@ -56,46 +56,46 @@ Base.length(x::DummyType) = size(x.X, 1)
5656 @testset " check multiple matrices" begin
5757 x, y = rand (rng, 3 , 3 ), rand (rng, 3 , 3 )
5858 jac_xs = jacobian (fdm, f1, x, y)
59- @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)
60- @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)
59+ @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)[ 1 ]
60+ @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)[ 1 ]
6161 end
6262
6363 @testset " check mixed scalar and matrices" begin
6464 x, y = rand (3 , 3 ), 2
6565 jac_xs = jacobian (fdm, f1, x, y)
66- @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)
67- @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)
66+ @test jac_xs[1 ] ≈ jacobian (fdm, x-> f1 (x, y), x)[ 1 ]
67+ @test jac_xs[2 ] ≈ jacobian (fdm, y-> f1 (x, y), y)[ 1 ]
6868 end
6969 end
7070
7171 @testset " grad" begin
7272 @testset " check multiple matrices" begin
7373 x, y = rand (rng, 3 , 3 ), rand (rng, 3 , 3 )
7474 dxs = grad (fdm, f2, x, y)
75- @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)
76- @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)
75+ @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)[ 1 ]
76+ @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)[ 1 ]
7777 end
7878
7979 @testset " check mixed scalar & matrices" begin
8080 x, y = rand (rng, 3 , 3 ), 2
8181 dxs = grad (fdm, f2, x, y)
82- @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)
83- @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)
82+ @test dxs[1 ] ≈ grad (fdm, x-> f2 (x, y), x)[ 1 ]
83+ @test dxs[2 ] ≈ grad (fdm, y-> f2 (x, y), y)[ 1 ]
8484 end
8585
8686 @testset " check tuple" begin
8787 x, y = rand (rng, 3 , 3 ), 2
88- dxs = grad (fdm, f3, (x, y))
89- @test dxs[1 ] ≈ grad (fdm, x-> f3 ((x, y)), x)
90- @test dxs[2 ] ≈ grad (fdm, y-> f3 ((x, y)), y)
88+ dxs = grad (fdm, f3, (x, y))[ 1 ]
89+ @test dxs[1 ] ≈ grad (fdm, x-> f3 ((x, y)), x)[ 1 ]
90+ @test dxs[2 ] ≈ grad (fdm, y-> f3 ((x, y)), y)[ 1 ]
9191 end
9292
9393 @testset " check dict" begin
9494 x, y = rand (rng, 3 , 3 ), 2
9595 d = Dict (:x => x, :y => y)
96- dxs = grad (fdm, f4, d)
97- @test dxs[:x ] ≈ grad (fdm, x-> f3 ((x, y)), x)
98- @test dxs[:y ] ≈ grad (fdm, y-> f3 ((x, y)), y)
96+ dxs = grad (fdm, f4, d)[ 1 ]
97+ @test dxs[:x ] ≈ grad (fdm, x-> f3 ((x, y)), x)[ 1 ]
98+ @test dxs[:y ] ≈ grad (fdm, y-> f3 ((x, y)), y)[ 1 ]
9999 end
100100 end
101101 end
@@ -168,8 +168,8 @@ Base.length(x::DummyType) = size(x.X, 1)
168168 x, y = randn (rng, T, N), randn (rng, T, M)
169169 z̄ = randn (rng, T, N + M)
170170 xy = vcat (x, y)
171- x̄ȳ_manual = j′vp (fdm, xy-> sin .(xy), z̄, xy)
172- x̄ȳ_auto = j′vp (fdm, x-> sin .(vcat (x[1 ], x[2 ])), z̄, (x, y))
171+ x̄ȳ_manual = j′vp (fdm, xy-> sin .(xy), z̄, xy)[ 1 ]
172+ x̄ȳ_auto = j′vp (fdm, x-> sin .(vcat (x[1 ], x[2 ])), z̄, (x, y))[ 1 ]
173173 x̄ȳ_multi = j′vp (fdm, (x, y)-> sin .(vcat (x, y)), z̄, x, y)
174174 @test x̄ȳ_manual ≈ vcat (x̄ȳ_auto... )
175175 @test x̄ȳ_manual ≈ vcat (x̄ȳ_multi... )
0 commit comments