@@ -77,13 +77,14 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo
7777 accs:: Accs
7878end
7979OnlyAccsVarInfo () = OnlyAccsVarInfo (default_accumulators ())
80+ DynamicPPL. maybe_invlink_before_eval!! (vi:: OnlyAccsVarInfo , :: Model ) = vi
8081DynamicPPL. getaccs (vi:: OnlyAccsVarInfo ) = vi. accs
8182DynamicPPL. setaccs!! (:: OnlyAccsVarInfo , accs:: AccumulatorTuple ) = OnlyAccsVarInfo (accs)
8283function DynamicPPL. get_param_eltype (:: OnlyAccsVarInfo , model:: Model )
8384 # Because the VarInfo has no parameters stored in it, we need to get the eltype from the
8485 # model's leaf context. This is only possible if said leaf context is indeed a FastEval
8586 # context.
86- leaf_ctx = DynamicPPL. leafcontext (model)
87+ leaf_ctx = DynamicPPL. leafcontext (model. context )
8788 if leaf_ctx isa FastEvalVectorContext
8889 return eltype (leaf_ctx. params)
8990 else
@@ -138,7 +139,8 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
138139It would be nice to unify the NamedTuple and Dict approach. See, e.g.
139140https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
140141"""
141- struct FastEvalVectorContext{N<: NamedTuple ,T<: AbstractVector{<:Real} } <: AbstractContext
142+ struct FastEvalVectorContext{N<: NamedTuple ,T<: AbstractVector{<:Real} } < :
143+ AbstractFastEvalContext
142144 # This NamedTuple stores the ranges for identity VarNames
143145 iden_varname_ranges:: N
144146 # This Dict stores the ranges for all other VarNames
331333function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
332334 ctx = FastEvalVectorContext (f. _iden_varname_ranges, f. _varname_ranges, params)
333335 model = DynamicPPL. setleafcontext (f. _model, ctx)
334- _, vi = _evaluate!! (model, OnlyAccsVarInfo (fast_ldf_accs (f. _getlogdensity)))
336+ only_accs_vi = OnlyAccsVarInfo (fast_ldf_accs (f. _getlogdensity))
337+ # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
338+ # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
339+ # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
340+ # here.
341+ vi = if Threads. nthreads () > 1
342+ ThreadSafeVarInfo (only_accs_vi)
343+ else
344+ only_accs_vi
345+ end
346+ _, vi = _evaluate!! (model, vi)
335347 return f. _getlogdensity (vi)
336348end
337349
@@ -360,30 +372,75 @@ end
360372# Helper functions to extract ranges and link status #
361373# #####################################################
362374
363- # TODO : Fails for other VarInfo types.
375+ # TODO : Fails for SimpleVarInfo. Do I really care enough? Ehhh, honestly, debatable.
376+
377+ """
378+ get_ranges_and_linked(varinfo::VarInfo)
379+
380+ Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
381+ representation, along with whether each variable is linked or unlinked.
382+
383+ This function should return a tuple containing:
384+
385+ - A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
386+ - A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
387+ """
364388function get_ranges_and_linked (varinfo:: VarInfo{<:NamedTuple{syms}} ) where {syms}
365389 all_iden_ranges = NamedTuple ()
366390 all_ranges = Dict {VarName,RangeAndLinked} ()
367391 offset = 1
368392 for sym in syms
369393 md = varinfo. metadata[sym]
370- # TODO : Fails for VarNamedVector.
371- for (vn, idx) in md. idcs
372- len = length (md. ranges[idx])
373- is_linked = md. is_transformed[idx]
374- range = offset: (offset + len - 1 )
375- if AbstractPPL. getoptic (vn) === identity
376- all_iden_ranges = merge (
377- all_iden_ranges,
378- NamedTuple ((
379- AbstractPPL. getsym (vn) => RangeAndLinked (range, is_linked),
380- )),
381- )
382- else
383- all_ranges[vn] = RangeAndLinked (range, is_linked)
384- end
385- offset += len
386- end
394+ this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata (
395+ md, offset
396+ )
397+ all_iden_ranges = merge (all_iden_ranges, this_md_iden)
398+ all_ranges = merge (all_ranges, this_md_others)
399+ offset = new_offset
387400 end
388401 return all_iden_ranges, all_ranges
389402end
403+ function get_ranges_and_linked (varinfo:: VarInfo{<:Metadata} )
404+ all_iden, all_others, _ = get_ranges_and_linked_metadata (varinfo. metadata, 1 )
405+ return all_iden, all_others
406+ end
407+ function get_ranges_and_linked_metadata (md:: Metadata , start_offset:: Int )
408+ all_iden_ranges = NamedTuple ()
409+ all_ranges = Dict {VarName,RangeAndLinked} ()
410+ offset = start_offset
411+ for (vn, idx) in md. idcs
412+ len = length (md. ranges[idx])
413+ is_linked = md. is_transformed[idx]
414+ range = offset: (offset + len - 1 )
415+ if AbstractPPL. getoptic (vn) === identity
416+ all_iden_ranges = merge (
417+ all_iden_ranges,
418+ NamedTuple ((AbstractPPL. getsym (vn) => RangeAndLinked (range, is_linked),)),
419+ )
420+ else
421+ all_ranges[vn] = RangeAndLinked (range, is_linked)
422+ end
423+ offset += len
424+ end
425+ return all_iden_ranges, all_ranges, offset
426+ end
427+ function get_ranges_and_linked_metadata (vnv:: VarNamedVector , start_offset:: Int )
428+ all_iden_ranges = NamedTuple ()
429+ all_ranges = Dict {VarName,RangeAndLinked} ()
430+ offset = start_offset
431+ for (vn, idx) in vnv. varname_to_index
432+ len = length (vnv. ranges[idx])
433+ is_linked = vnv. is_unconstrained[idx]
434+ range = offset: (offset + len - 1 )
435+ if AbstractPPL. getoptic (vn) === identity
436+ all_iden_ranges = merge (
437+ all_iden_ranges,
438+ NamedTuple ((AbstractPPL. getsym (vn) => RangeAndLinked (range, is_linked),)),
439+ )
440+ else
441+ all_ranges[vn] = RangeAndLinked (range, is_linked)
442+ end
443+ offset += len
444+ end
445+ return all_iden_ranges, all_ranges, offset
446+ end
0 commit comments