Skip to content

Commit 85ad83c

Browse files
committed
Move FastLDF to experimental for now
1 parent 46033ef commit 85ad83c

File tree

7 files changed

+75
-31
lines changed

7 files changed

+75
-31
lines changed

benchmarks/benchmarks.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ chosen_combinations = [
5959
false,
6060
),
6161
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
62-
# ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
63-
# ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
64-
# ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
65-
# ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
66-
# ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
62+
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
63+
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
64+
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
65+
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
66+
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
6767
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
6868
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
69-
# ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
69+
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
7070
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
7171
("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
7272
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
9494
vi = DynamicPPL.link(vi, model)
9595
end
9696

97-
f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend)
97+
f = DynamicPPL.LogDensityFunction(
98+
model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
99+
)
98100
# The parameters at which we evaluate f.
99101
θ = vi[:]
100102

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ include("simple_varinfo.jl")
191191
include("compiler.jl")
192192
include("pointwise_logdensities.jl")
193193
include("logdensityfunction.jl")
194-
include("fastldf.jl")
195194
include("model_utils.jl")
196195
include("extract_priors.jl")
197196
include("values_as_in_model.jl")

src/experimental.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module Experimental
22

33
using DynamicPPL: DynamicPPL
44

5+
include("fastldf.jl")
6+
57
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
68
"""
79
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)
File renamed without changes.

src/test_utils/ad.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff
44
using Chairmarks: @be
55
import DifferentiationInterface as DI
66
using DocStringExtensions
7-
using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link
7+
using DynamicPPL:
8+
Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link
89
using LogDensityProblems: logdensity, logdensity_and_gradient
910
using Random: AbstractRNG, default_rng
1011
using Statistics: median
@@ -264,7 +265,7 @@ function run_ad(
264265
# Calculate log-density and gradient with the backend of interest
265266
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
266267
verbose && println(" params : $(params)")
267-
ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype)
268+
ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype)
268269

269270
value, grad = logdensity_and_gradient(ldf, params)
270271
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
@@ -281,7 +282,9 @@ function run_ad(
281282
value_true = test.value
282283
grad_true = test.grad
283284
elseif test isa WithBackend
284-
ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype)
285+
ldf_reference = LogDensityFunction(
286+
model, getlogdensity, varinfo; adtype=test.adtype
287+
)
285288
value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
286289
# collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
287290
grad_true = collect(grad_true)

test/ad.jl

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
using DynamicPPL: FastLDF
1+
using DynamicPPL: LogDensityFunction
22
using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
3-
using Random: Xoshiro
43

54
@testset "Automatic differentiation" begin
65
# Used as the ground truth that others are compared against.
@@ -16,25 +15,64 @@ using Random: Xoshiro
1615
[AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
1716
end
1817

18+
@testset "Unsupported backends" begin
19+
@model demo() = x ~ Normal()
20+
@test_logs (:warn, r"not officially supported") LogDensityFunction(
21+
demo(); adtype=AutoZygote()
22+
)
23+
end
24+
1925
@testset "Correctness" begin
2026
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
21-
varinfo = VarInfo(Xoshiro(468), m)
22-
linked_varinfo = DynamicPPL.link(varinfo, m)
23-
f = FastLDF(m, getlogjoint_internal, linked_varinfo)
24-
x = linked_varinfo[:]
25-
26-
# Calculate reference logp + gradient of logp using ForwardDiff
27-
ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
28-
ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
29-
30-
@testset "$adtype" for adtype in test_adtypes
31-
@info "Testing AD on: $(m.f) - $adtype"
32-
@test run_ad(
33-
m,
34-
adtype;
35-
varinfo=linked_varinfo,
36-
test=WithExpectedResult(ref_logp, ref_grad),
37-
) isa Any
27+
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
28+
vns = DynamicPPL.TestUtils.varnames(m)
29+
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
30+
31+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
32+
linked_varinfo = DynamicPPL.link(varinfo, m)
33+
f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo)
34+
x = DynamicPPL.getparams(f)
35+
36+
# Calculate reference logp + gradient of logp using ForwardDiff
37+
ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
38+
ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
39+
40+
@testset "$adtype" for adtype in test_adtypes
41+
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
42+
43+
# Put predicates here to avoid long lines
44+
is_mooncake = adtype isa AutoMooncake
45+
is_1_10 = v"1.10" <= VERSION < v"1.11"
46+
is_1_11 = v"1.11" <= VERSION < v"1.12"
47+
is_svi_vnv =
48+
linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
49+
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}
50+
51+
# Mooncake doesn't work with several combinations of SimpleVarInfo.
52+
if is_mooncake && is_1_11 && is_svi_vnv
53+
# https://github.com/compintell/Mooncake.jl/issues/470
54+
@test_throws ArgumentError DynamicPPL.LogDensityFunction(
55+
m, getlogjoint_internal, linked_varinfo; adtype=adtype
56+
)
57+
elseif is_mooncake && is_1_10 && is_svi_vnv
58+
# TODO: report upstream
59+
@test_throws UndefRefError DynamicPPL.LogDensityFunction(
60+
m, getlogjoint_internal, linked_varinfo; adtype=adtype
61+
)
62+
elseif is_mooncake && is_1_10 && is_svi_od
63+
# TODO: report upstream
64+
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
65+
m, getlogjoint_internal, linked_varinfo; adtype=adtype
66+
)
67+
else
68+
@test run_ad(
69+
m,
70+
adtype;
71+
varinfo=linked_varinfo,
72+
test=WithExpectedResult(ref_logp, ref_grad),
73+
) isa Any
74+
end
75+
end
3876
end
3977
end
4078
end
@@ -45,7 +83,7 @@ using Random: Xoshiro
4583
test_m = randn(2, 3)
4684

4785
function eval_logp_and_grad(model, m, adtype)
48-
ldf = FastLDF(model(); adtype=adtype)
86+
ldf = LogDensityFunction(model(); adtype=adtype)
4987
return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
5088
end
5189

0 commit comments

Comments
 (0)