@@ -50,14 +50,10 @@ Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Suc
5050representations are capable of handling models with variable sizes and stochastic control
5151flow.
5252
53- However, the path towards implementing these is straightforward:
54-
55- 1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter
56- value, plus a boolean indicating whether the value is linked or unlinked. See the
57- `get_range_and_linked` function for details.
58-
59- 2. We would need to implement similar contexts for NamedTuple and Dict parameters. The
60- functionality would be quite similar to `InitContext(InitFromParams(...))`.
53+ However, the path towards implementing these is straightforward: just make `InitContext` work
54+ correctly with `OnlyAccsVarInfo`. There will probably be a few functions that need to be
55+ overloaded to make this work: for example `push!!` on `OnlyAccsVarInfo` can just be defined
56+ as a no-op.
6157"""
6258
6359using DynamicPPL:
@@ -119,6 +115,13 @@ function DynamicPPL.get_param_eltype(
119115 if leaf_ctx isa FastEvalVectorContext
120116 return eltype (leaf_ctx. params)
121117 else
118+ # TODO (penelopeysm): In principle this can be done with InitContext{InitWithParams}.
119+ # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
120+ # figure out the parameter type from a NamedTuple or Dict. The benefit of
121+ # implementing this for InitContext is that we could then use OnlyAccsVarInfo with
122+ # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
123+ # that Mooncake / Enzyme should be able to differentiate through that too and
124+ # provide a NamedTuple of gradients (although I haven't tested this yet).
122125 error (
123126 " OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof (leaf_ctx)) " ,
124127 )
@@ -188,7 +191,7 @@ function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName)
188191 return ctx. varname_ranges[vn]
189192end
190193
191- function tilde_assume!! (
194+ function DynamicPPL . tilde_assume!! (
192195 ctx:: FastEvalVectorContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
193196)
194197 # Note that this function does not use the metadata field of `vi` at all.
@@ -204,7 +207,7 @@ function tilde_assume!!(
204207 return x, vi
205208end
206209
207- function tilde_observe!! (
210+ function DynamicPPL . tilde_observe!! (
208211 :: FastEvalVectorContext ,
209212 right:: Distribution ,
210213 left,
@@ -369,6 +372,9 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
369372 # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
370373 # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
371374 # here.
375+ # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
376+ # it _should_ do, but this is wrong regardless.
377+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
372378 vi = if Threads. nthreads () > 1
373379 accs = map (
374380 acc -> DynamicPPL. convert_eltype (float_type_with_fallback (eltype (params)), acc),
0 commit comments