1- using DynamicPPL: FastLDF
1+ using DynamicPPL: LogDensityFunction
22using DynamicPPL. TestUtils. AD: run_ad, WithExpectedResult, NoTest
3- using Random: Xoshiro
43
54@testset " Automatic differentiation" begin
65 # Used as the ground truth that others are compared against.
@@ -16,25 +15,64 @@ using Random: Xoshiro
1615 [AutoReverseDiff (; compile= false ), AutoReverseDiff (; compile= true )]
1716 end
1817
18+ @testset " Unsupported backends" begin
19+ @model demo () = x ~ Normal ()
20+ @test_logs (:warn , r" not officially supported" ) LogDensityFunction (
21+ demo (); adtype= AutoZygote ()
22+ )
23+ end
24+
1925 @testset " Correctness" begin
2026 @testset " $(m. f) " for m in DynamicPPL. TestUtils. DEMO_MODELS
21- varinfo = VarInfo (Xoshiro (468 ), m)
22- linked_varinfo = DynamicPPL. link (varinfo, m)
23- f = FastLDF (m, getlogjoint_internal, linked_varinfo)
24- x = linked_varinfo[:]
25-
26- # Calculate reference logp + gradient of logp using ForwardDiff
27- ref_ad_result = run_ad (m, ref_adtype; varinfo= linked_varinfo, test= NoTest ())
28- ref_logp, ref_grad = ref_ad_result. value_actual, ref_ad_result. grad_actual
29-
30- @testset " $adtype " for adtype in test_adtypes
31- @info " Testing AD on: $(m. f) - $adtype "
32- @test run_ad (
33- m,
34- adtype;
35- varinfo= linked_varinfo,
36- test= WithExpectedResult (ref_logp, ref_grad),
37- ) isa Any
27+ rand_param_values = DynamicPPL. TestUtils. rand_prior_true (m)
28+ vns = DynamicPPL. TestUtils. varnames (m)
29+ varinfos = DynamicPPL. TestUtils. setup_varinfos (m, rand_param_values, vns)
30+
31+ @testset " $(short_varinfo_name (varinfo)) " for varinfo in varinfos
32+ linked_varinfo = DynamicPPL. link (varinfo, m)
33+ f = LogDensityFunction (m, getlogjoint_internal, linked_varinfo)
34+ x = DynamicPPL. getparams (f)
35+
36+ # Calculate reference logp + gradient of logp using ForwardDiff
37+ ref_ad_result = run_ad (m, ref_adtype; varinfo= linked_varinfo, test= NoTest ())
38+ ref_logp, ref_grad = ref_ad_result. value_actual, ref_ad_result. grad_actual
39+
40+ @testset " $adtype " for adtype in test_adtypes
41+ @info " Testing AD on: $(m. f) - $(short_varinfo_name (linked_varinfo)) - $adtype "
42+
43+ # Put predicates here to avoid long lines
44+ is_mooncake = adtype isa AutoMooncake
45+ is_1_10 = v " 1.10" <= VERSION < v " 1.11"
46+ is_1_11 = v " 1.11" <= VERSION < v " 1.12"
47+ is_svi_vnv =
48+ linked_varinfo isa SimpleVarInfo{<: DynamicPPL.VarNamedVector }
49+ is_svi_od = linked_varinfo isa SimpleVarInfo{<: OrderedDict }
50+
51+ # Mooncake doesn't work with several combinations of SimpleVarInfo.
52+ if is_mooncake && is_1_11 && is_svi_vnv
53+ # https://github.com/compintell/Mooncake.jl/issues/470
54+ @test_throws ArgumentError DynamicPPL. LogDensityFunction (
55+ m, getlogjoint_internal, linked_varinfo; adtype= adtype
56+ )
57+ elseif is_mooncake && is_1_10 && is_svi_vnv
58+ # TODO : report upstream
59+ @test_throws UndefRefError DynamicPPL. LogDensityFunction (
60+ m, getlogjoint_internal, linked_varinfo; adtype= adtype
61+ )
62+ elseif is_mooncake && is_1_10 && is_svi_od
63+ # TODO : report upstream
64+ @test_throws Mooncake. MooncakeRuleCompilationError DynamicPPL. LogDensityFunction (
65+ m, getlogjoint_internal, linked_varinfo; adtype= adtype
66+ )
67+ else
68+ @test run_ad (
69+ m,
70+ adtype;
71+ varinfo= linked_varinfo,
72+ test= WithExpectedResult (ref_logp, ref_grad),
73+ ) isa Any
74+ end
75+ end
3876 end
3977 end
4078 end
@@ -45,7 +83,7 @@ using Random: Xoshiro
4583 test_m = randn (2 , 3 )
4684
4785 function eval_logp_and_grad (model, m, adtype)
48- ldf = FastLDF (model (); adtype= adtype)
86+ ldf = LogDensityFunction (model (); adtype= adtype)
4987 return LogDensityProblems. logdensity_and_gradient (ldf, m[:])
5088 end
5189
0 commit comments