Skip to content

Commit 0286756

Browse files
committed
Fix imports, add tests, etc
1 parent 85ad83c commit 0286756

File tree

4 files changed

+153
-5
lines changed

4 files changed

+153
-5
lines changed

src/experimental.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Experimental
22

33
using DynamicPPL: DynamicPPL
44

5-
include("fastldf.jl")
5+
include("fasteval.jl")
66

77
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
88
"""

src/fasteval.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,35 @@ However, the path towards implementing these is straightforward:
6060
functionality would be quite similar to `InitContext(InitFromParams(...))`.
6161
"""
6262

63+
using DynamicPPL:
64+
AbstractContext,
65+
AbstractVarInfo,
66+
AccumulatorTuple,
67+
Metadata,
68+
Model,
69+
ThreadSafeVarInfo,
70+
VarInfo,
71+
VarNamedVector,
72+
accumulate_assume!!,
73+
accumulate_observe!!,
74+
default_accumulators,
75+
float_type_with_fallback,
76+
from_linked_vec_transform,
77+
from_vec_transform,
78+
getlogjoint,
79+
getlogjoint_internal,
80+
getloglikelihood,
81+
getlogprior,
82+
getlogprior_internal,
83+
leafcontext
84+
using ADTypes: ADTypes
85+
using Bijectors: with_logabsdet_jacobian
86+
using AbstractPPL: AbstractPPL, VarName
87+
using Distributions: Distribution
88+
using DocStringExtensions: TYPEDFIELDS
89+
using LogDensityProblems: LogDensityProblems
90+
import DifferentiationInterface as DI
91+
6392
"""
6493
OnlyAccsVarInfo
6594
@@ -121,7 +150,7 @@ Abstract type representing fast evaluation contexts. This currently is only subt
121150
NamedTuple and Dict parameters.
122151
"""
123152
abstract type AbstractFastEvalContext <: AbstractContext end
124-
DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf()
153+
DynamicPPL.NodeTrait(::AbstractFastEvalContext) = DynamicPPL.IsLeaf()
125154

126155
"""
127156
FastEvalVectorContext(
@@ -286,7 +315,7 @@ struct FastLDF{
286315
nothing
287316
else
288317
# Make backend-specific tweaks to the adtype
289-
adtype = tweak_adtype(adtype, model, varinfo)
318+
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
290319
x = [val for val in varinfo[:]]
291320
DI.prepare_gradient(
292321
FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
@@ -341,12 +370,15 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
341370
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
342371
# here.
343372
vi = if Threads.nthreads() > 1
344-
accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs)
373+
accs = map(
374+
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
375+
accs,
376+
)
345377
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
346378
else
347379
OnlyAccsVarInfo(accs)
348380
end
349-
_, vi = _evaluate!!(model, vi)
381+
_, vi = DynamicPPL._evaluate!!(model, vi)
350382
return f._getlogdensity(vi)
351383
end
352384

test/fasteval.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
module DynamicPPLFastLDFTests
2+
3+
using DynamicPPL
4+
using Distributions
5+
using DistributionsAD: filldist
6+
using ADTypes
7+
using DynamicPPL.Experimental: FastLDF
8+
9+
@testset "Automatic differentiation" begin
10+
# Used as the ground truth that others are compared against.
11+
ref_adtype = AutoForwardDiff()
12+
13+
test_adtypes = if MOONCAKE_SUPPORTED
14+
[
15+
AutoReverseDiff(; compile=false),
16+
AutoReverseDiff(; compile=true),
17+
AutoMooncake(; config=nothing),
18+
]
19+
else
20+
[AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
21+
end
22+
23+
@testset "Unsupported backends" begin
24+
@model demo() = x ~ Normal()
25+
@test_logs (:warn, r"not officially supported") FastLDF(demo(); adtype=AutoZygote())
26+
end
27+
28+
@testset "Correctness" begin
29+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
30+
varinfo = VarInfo(m)
31+
linked_varinfo = DynamicPPL.link(varinfo, m)
32+
f = FastLDF(m, getlogjoint_internal, linked_varinfo)
33+
x = linked_varinfo[:]
34+
35+
# Calculate reference logp + gradient of logp using ForwardDiff
36+
ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
37+
ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
38+
39+
@testset "$adtype" for adtype in test_adtypes
40+
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
41+
42+
@test run_ad(
43+
m,
44+
adtype;
45+
varinfo=linked_varinfo,
46+
test=WithExpectedResult(ref_logp, ref_grad),
47+
) isa Any
48+
end
49+
end
50+
end
51+
52+
# Test that various different ways of specifying array types as arguments work with all
53+
# ADTypes.
54+
@testset "Array argument types" begin
55+
test_m = randn(2, 3)
56+
57+
function eval_logp_and_grad(model, m, adtype)
58+
ldf = FastLDF(model(); adtype=adtype)
59+
return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
60+
end
61+
62+
@model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
63+
m = Matrix{T}(undef, 2, 3)
64+
return m ~ filldist(MvNormal(zeros(2), I), 3)
65+
end
66+
67+
scalar_matrix_model_reference = eval_logp_and_grad(
68+
scalar_matrix_model, test_m, ref_adtype
69+
)
70+
71+
@model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
72+
m = T(undef, 2, 3)
73+
return m ~ filldist(MvNormal(zeros(2), I), 3)
74+
end
75+
76+
matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
77+
78+
@model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
79+
m = Array{T}(undef, 2, 3)
80+
return m ~ filldist(MvNormal(zeros(2), I), 3)
81+
end
82+
83+
scalar_array_model_reference = eval_logp_and_grad(
84+
scalar_array_model, test_m, ref_adtype
85+
)
86+
87+
@model function array_model(::Type{T}=Array{Float64}) where {T}
88+
m = T(undef, 2, 3)
89+
return m ~ filldist(MvNormal(zeros(2), I), 3)
90+
end
91+
92+
array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype)
93+
94+
@testset "$adtype" for adtype in test_adtypes
95+
scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
96+
scalar_matrix_model, test_m, adtype
97+
)
98+
@test scalar_matrix_model_logp_and_grad[1] scalar_matrix_model_reference[1]
99+
@test scalar_matrix_model_logp_and_grad[2] scalar_matrix_model_reference[2]
100+
matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
101+
@test matrix_model_logp_and_grad[1] matrix_model_reference[1]
102+
@test matrix_model_logp_and_grad[2] matrix_model_reference[2]
103+
scalar_array_model_logp_and_grad = eval_logp_and_grad(
104+
scalar_array_model, test_m, adtype
105+
)
106+
@test scalar_array_model_logp_and_grad[1] scalar_array_model_reference[1]
107+
@test scalar_array_model_logp_and_grad[2] scalar_array_model_reference[2]
108+
array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
109+
@test array_model_logp_and_grad[1] array_model_reference[1]
110+
@test array_model_logp_and_grad[2] array_model_reference[2]
111+
end
112+
end
113+
end
114+
115+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ include("test_util.jl")
8989
include("ext/DynamicPPLMooncakeExt.jl")
9090
end
9191
include("ad.jl")
92+
include("fasteval.jl")
9293
end
9394
@testset "prob and logprob macro" begin
9495
@test_throws ErrorException prob"..."

0 commit comments

Comments
 (0)