Skip to content

Commit f3b81d1

Browse files
committed
Add correctness tests, fix imports
1 parent b921789 commit f3b81d1

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/fasteval.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ using DynamicPPL:
6060
AbstractContext,
6161
AbstractVarInfo,
6262
AccumulatorTuple,
63+
LogJacobianAccumulator,
64+
LogLikelihoodAccumulator,
65+
LogPriorAccumulator,
6366
Metadata,
6467
Model,
6568
ThreadSafeVarInfo,

test/fasteval.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using ReverseDiff: ReverseDiff
2020
using Mooncake: Mooncake
2121
end
2222

23-
@testset "get_ranges_and_linked" begin
23+
@testset "FastLDF: Correctness" begin
2424
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
2525
@testset "$varinfo_func" for varinfo_func in [
2626
DynamicPPL.untyped_varinfo,
@@ -51,6 +51,26 @@ end
5151
# Check that the link status is correct
5252
@test range_with_linked.is_linked == islinked
5353
end
54+
55+
# Compare results of FastLDF vs ordinary LogDensityFunction. These tests
56+
# can eventually go once we replace LogDensityFunction with FastLDF, but
57+
# for now it helps to have this check! (Eventually we should just check
58+
# against manually computed log-densities).
59+
#
60+
# TODO(penelopeysm): I think we need to add tests for some really
61+
# pathological models here.
62+
@testset "$getlogdensity" for getlogdensity in (
63+
DynamicPPL.getlogjoint_internal,
64+
DynamicPPL.getlogjoint,
65+
DynamicPPL.getloglikelihood,
66+
DynamicPPL.getlogprior_internal,
67+
DynamicPPL.getlogprior,
68+
)
69+
ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi)
70+
fldf = FastLDF(m, getlogdensity, vi)
71+
@test LogDensityProblems.logdensity(ldf, params)
72+
LogDensityProblems.logdensity(fldf, params)
73+
end
5474
end
5575
end
5676
end

0 commit comments

Comments
 (0)