|
1 | | -using DynamicPPL |
2 | | -using Distributions |
3 | | - |
4 | | -using ForwardDiff |
5 | | -using Zygote |
6 | | -using Tracker |
7 | | - |
8 | | -@testset "logp" begin |
9 | | - @model function admodel() |
10 | | - s ~ InverseGamma(2, 3) |
11 | | - m ~ Normal(0, sqrt(s)) |
12 | | - 1.5 ~ Normal(m, sqrt(s)) |
13 | | - 2.0 ~ Normal(m, sqrt(s)) |
14 | | - return s, m |
15 | | - end |
16 | | - |
17 | | - model = admodel() |
18 | | - vi = VarInfo(model) |
19 | | - model(vi, SampleFromPrior()) |
20 | | - x = [vi[@varname(s)], vi[@varname(m)]] |
21 | | - |
22 | | - dist_s = InverseGamma(2,3) |
23 | | - |
24 | | - # Hand-written log probabilities for vector `x = [s, m]`. |
25 | | - function logp_manual(x) |
26 | | - s = x[1] |
27 | | - m = x[2] |
28 | | - dist = Normal(m, sqrt(s)) |
29 | | - |
30 | | - return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + |
31 | | - logpdf(dist, 1.5) + logpdf(dist, 2.0) |
32 | | - end |
33 | | - |
34 | | - # Log probabilities for vector `x = [s, m]` using the model. |
35 | | - function logp_model(x) |
36 | | - new_vi = VarInfo(vi, SampleFromPrior(), x) |
37 | | - model(new_vi, SampleFromPrior()) |
38 | | - return getlogp(new_vi) |
| 1 | +@testset "ad.jl" begin |
| 2 | + @testset "logp" begin |
| 3 | + # Hand-written log probabilities for vector `x = [s, m]`. |
| 4 | + function logp_gdemo_default(x) |
| 5 | + s = x[1] |
| 6 | + m = x[2] |
| 7 | + dist = Normal(m, sqrt(s)) |
| 8 | + |
| 9 | + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + |
| 10 | + logpdf(dist, 1.5) + logpdf(dist, 2.0) |
| 11 | + end |
| 12 | + |
| 13 | + test_model_ad(gdemo_default, logp_gdemo_default) |
| 14 | + |
| 15 | + @model function wishart_ad() |
| 16 | + v ~ Wishart(7, [1 0.5; 0.5 1]) |
| 17 | + end |
| 18 | + |
| 19 | + # Hand-written log probabilities for `x = [v]`. |
| 20 | + function logp_wishart_ad(x) |
| 21 | + dist = Wishart(7, [1 0.5; 0.5 1]) |
| 22 | + return logpdf(dist, reshape(x, 2, 2)) |
| 23 | + end |
| 24 | + |
| 25 | + test_model_ad(wishart_ad(), logp_wishart_ad) |
39 | 26 | end |
40 | | - |
41 | | - # Check that both functions return the same values. |
42 | | - lp = logp_manual(x) |
43 | | - @test logp_model(x) ≈ lp |
44 | | - |
45 | | - # Gradients based on the manual implementation. |
46 | | - grad = ForwardDiff.gradient(logp_manual, x) |
47 | | - |
48 | | - y, back = Tracker.forward(logp_manual, x) |
49 | | - @test Tracker.data(y) ≈ lp |
50 | | - @test Tracker.data(back(1)[1]) ≈ grad |
51 | | - |
52 | | - y, back = Zygote.pullback(logp_manual, x) |
53 | | - @test y ≈ lp |
54 | | - @test back(1)[1] ≈ grad |
55 | | - |
56 | | - # Gradients based on the model. |
57 | | - @test ForwardDiff.gradient(logp_model, x) ≈ grad |
58 | | - |
59 | | - y, back = Tracker.forward(logp_model, x) |
60 | | - @test Tracker.data(y) ≈ lp |
61 | | - @test Tracker.data(back(1)[1]) ≈ grad |
62 | | - |
63 | | - y, back = Zygote.pullback(logp_model, x) |
64 | | - @test y ≈ lp |
65 | | - @test back(1)[1] ≈ grad |
66 | 27 | end |
67 | | - |
0 commit comments