Skip to content

Commit e030d42

Browse files
committed
Fix imports / tests
1 parent 405fdff commit e030d42

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

src/fasteval.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,10 @@ Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Suc
5050
representations are capable of handling models with variable sizes and stochastic control
5151
flow.
5252
53-
However, the path towards implementing these is straightforward:
54-
55-
1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter
56-
value, plus a boolean indicating whether the value is linked or unlinked. See the
57-
`get_range_and_linked` function for details.
58-
59-
2. We would need to implement similar contexts for NamedTuple and Dict parameters. The
60-
functionality would be quite similar to `InitContext(InitFromParams(...))`.
53+
However, the path towards implementing these is straightforward: just make `InitContext` work
54+
correctly with `OnlyAccsVarInfo`. There will probably be a few functions that need to be
55+
overloaded to make this work: for example `push!!` on `OnlyAccsVarInfo` can just be defined
56+
as a no-op.
6157
"""
6258

6359
using DynamicPPL:
@@ -119,6 +115,13 @@ function DynamicPPL.get_param_eltype(
119115
if leaf_ctx isa FastEvalVectorContext
120116
return eltype(leaf_ctx.params)
121117
else
118+
# TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}.
119+
# See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
120+
# figure out the parameter type from a NamedTuple or Dict. The benefit of
121+
# implementing this for InitContext is that we could then use OnlyAccsVarInfo with
122+
# it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
123+
# that Mooncake / Enzyme should be able to differentiate through that too and
124+
# provide a NamedTuple of gradients (although I haven't tested this yet).
122125
error(
123126
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
124127
)
@@ -188,7 +191,7 @@ function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName)
188191
return ctx.varname_ranges[vn]
189192
end
190193

191-
function tilde_assume!!(
194+
function DynamicPPL.tilde_assume!!(
192195
ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
193196
)
194197
# Note that this function does not use the metadata field of `vi` at all.
@@ -204,7 +207,7 @@ function tilde_assume!!(
204207
return x, vi
205208
end
206209

207-
function tilde_observe!!(
210+
function DynamicPPL.tilde_observe!!(
208211
::FastEvalVectorContext,
209212
right::Distribution,
210213
left,
@@ -369,6 +372,9 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
369372
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
370373
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
371374
# here.
375+
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
376+
# it _should_ do, but this is wrong regardless.
377+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
372378
vi = if Threads.nthreads() > 1
373379
accs = map(
374380
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),

test/fasteval.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
module DynamicPPLFastLDFTests
22

3+
using AbstractPPL: AbstractPPL
34
using DynamicPPL
45
using Distributions
56
using DistributionsAD: filldist
67
using ADTypes
78
using DynamicPPL.Experimental: FastLDF
89
using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
10+
using LinearAlgebra: I
911
using Test
12+
using LogDensityProblems: LogDensityProblems
1013

1114
using ForwardDiff: ForwardDiff
1215
using ReverseDiff: ReverseDiff
@@ -17,7 +20,43 @@ using ReverseDiff: ReverseDiff
1720
using Mooncake: Mooncake
1821
end
1922

20-
@testset "Automatic differentiation" begin
23+
@testset "get_ranges_and_linked" begin
24+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
25+
@testset "$varinfo_func" for varinfo_func in [
26+
DynamicPPL.untyped_varinfo,
27+
DynamicPPL.typed_varinfo,
28+
DynamicPPL.untyped_vector_varinfo,
29+
DynamicPPL.typed_vector_varinfo,
30+
]
31+
unlinked_vi = varinfo_func(m)
32+
@testset "$islinked" for islinked in (false, true)
33+
vi = if islinked
34+
DynamicPPL.link!!(unlinked_vi, m)
35+
else
36+
unlinked_vi
37+
end
38+
nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi)
39+
params = vi[:]
40+
# Iterate over all variables
41+
for vn in keys(vi)
42+
# Check that `getindex_internal` returns the same thing as using the ranges
43+
# directly
44+
range_with_linked = if AbstractPPL.getoptic(vn) === identity
45+
nt_ranges[AbstractPPL.getsym(vn)]
46+
else
47+
dict_ranges[vn]
48+
end
49+
@test params[range_with_linked.range] ==
50+
DynamicPPL.getindex_internal(vi, vn)
51+
# Check that the link status is correct
52+
@test range_with_linked.is_linked == islinked
53+
end
54+
end
55+
end
56+
end
57+
end
58+
59+
@testset "AD with FastLDF" begin
2160
# Used as the ground truth that others are compared against.
2261
ref_adtype = AutoForwardDiff()
2362

@@ -43,7 +82,7 @@ end
4382
ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
4483

4584
@testset "$adtype" for adtype in test_adtypes
46-
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
85+
@info "Testing AD on: $(m.f) - $adtype"
4786

4887
@test run_ad(
4988
m,

0 commit comments

Comments
 (0)