Skip to content

Commit 462c32f

Browse files
committed
Add basic AD test
1 parent 9f0c204 commit 462c32f

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
4545
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4646
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4747
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
48+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4849

4950
[targets]
50-
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
51+
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]

test/compat/ad.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using DynamicPPL
2+
using Distributions
3+
4+
using ForwardDiff
5+
using Zygote
6+
using Tracker
7+
8+
@testset "logp" begin
9+
@model function admodel()
10+
s ~ InverseGamma(2, 3)
11+
m ~ Normal(0, sqrt(s))
12+
1.5 ~ Normal(m, sqrt(s))
13+
2.0 ~ Normal(m, sqrt(s))
14+
return s, m
15+
end
16+
17+
model = admodel()
18+
vi = VarInfo(model)
19+
model(vi, SampleFromPrior())
20+
x = [vi[@varname(s)], vi[@varname(m)]]
21+
22+
dist_s = InverseGamma(2,3)
23+
24+
# Hand-written log probabilities for vector `x = [s, m]`.
25+
function logp_manual(x)
26+
s = x[1]
27+
m = x[2]
28+
dist = Normal(m, sqrt(s))
29+
30+
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +
31+
logpdf(dist, 1.5) + logpdf(dist, 2.0)
32+
end
33+
34+
# Log probabilities for vector `x = [s, m]` using the model.
35+
function logp_model(x)
36+
new_vi = VarInfo(vi, SampleFromPrior(), x)
37+
model(new_vi, SampleFromPrior())
38+
return getlogp(new_vi)
39+
end
40+
41+
# Check that both functions return the same values.
42+
lp = logp_manual(x)
43+
@test logp_model(x) lp
44+
45+
# Gradients based on the manual implementation.
46+
grad = ForwardDiff.gradient(logp_manual, x)
47+
48+
y, back = Tracker.forward(logp_manual, x)
49+
@test Tracker.data(y) lp
50+
@test Tracker.data(back(1)[1]) grad
51+
52+
y, back = Zygote.pullback(logp_manual, x)
53+
@test y lp
54+
@test back(1)[1] grad
55+
56+
# Gradients based on the model.
57+
@test ForwardDiff.gradient(logp_model, x) grad
58+
59+
y, back = Tracker.forward(logp_model, x)
60+
@test Tracker.data(y) lp
61+
@test Tracker.data(back(1)[1]) grad
62+
63+
y, back = Zygote.gradient(logp_model, x)
64+
@test y lp
65+
@test back(1) grad
66+
end
67+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ turnprogress(false)
88
@testset "DynamicPPL.jl" begin
99
include("utils.jl")
1010
include("compiler.jl")
11+
include("compat/ad.jl")
1112
include("varinfo.jl")
1213
include("sampler.jl")
1314
include("prob_macro.jl")

0 commit comments

Comments
 (0)