Skip to content

Commit 8f43f2e

Browse files
committed
Support more VarInfos, make it thread-safe (?)
1 parent 4fe97cf commit 8f43f2e

File tree

1 file changed

+78
-21
lines changed

1 file changed

+78
-21
lines changed

src/fastldf.jl

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo
7777
accs::Accs
7878
end
7979
OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
80+
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
8081
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
8182
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
8283
function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model)
8384
# Because the VarInfo has no parameters stored in it, we need to get the eltype from the
8485
# model's leaf context. This is only possible if said leaf context is indeed a FastEval
8586
# context.
86-
leaf_ctx = DynamicPPL.leafcontext(model)
87+
leaf_ctx = DynamicPPL.leafcontext(model.context)
8788
if leaf_ctx isa FastEvalVectorContext
8889
return eltype(leaf_ctx.params)
8990
else
@@ -138,7 +139,8 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
138139
It would be nice to unify the NamedTuple and Dict approach. See, e.g.
139140
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
140141
"""
141-
struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext
142+
struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <:
143+
AbstractFastEvalContext
142144
# This NamedTuple stores the ranges for identity VarNames
143145
iden_varname_ranges::N
144146
# This Dict stores the ranges for all other VarNames
@@ -331,7 +333,17 @@ end
331333
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
332334
ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
333335
model = DynamicPPL.setleafcontext(f._model, ctx)
334-
_, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)))
336+
only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))
337+
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
338+
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
339+
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
340+
# here.
341+
vi = if Threads.nthreads() > 1
342+
ThreadSafeVarInfo(only_accs_vi)
343+
else
344+
only_accs_vi
345+
end
346+
_, vi = _evaluate!!(model, vi)
335347
return f._getlogdensity(vi)
336348
end
337349

@@ -360,30 +372,75 @@ end
360372
# Helper functions to extract ranges and link status #
361373
######################################################
362374

363-
# TODO: Fails for other VarInfo types.
375+
# TODO: Fails for SimpleVarInfo. Do I really care enough? Ehhh, honestly, debatable.
376+
377+
"""
378+
get_ranges_and_linked(varinfo::VarInfo)
379+
380+
Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
381+
representation, along with whether each variable is linked or unlinked.
382+
383+
This function should return a tuple containing:
384+
385+
- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
386+
- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
387+
"""
364388
function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
365389
all_iden_ranges = NamedTuple()
366390
all_ranges = Dict{VarName,RangeAndLinked}()
367391
offset = 1
368392
for sym in syms
369393
md = varinfo.metadata[sym]
370-
# TODO: Fails for VarNamedVector.
371-
for (vn, idx) in md.idcs
372-
len = length(md.ranges[idx])
373-
is_linked = md.is_transformed[idx]
374-
range = offset:(offset + len - 1)
375-
if AbstractPPL.getoptic(vn) === identity
376-
all_iden_ranges = merge(
377-
all_iden_ranges,
378-
NamedTuple((
379-
AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),
380-
)),
381-
)
382-
else
383-
all_ranges[vn] = RangeAndLinked(range, is_linked)
384-
end
385-
offset += len
386-
end
394+
this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata(
395+
md, offset
396+
)
397+
all_iden_ranges = merge(all_iden_ranges, this_md_iden)
398+
all_ranges = merge(all_ranges, this_md_others)
399+
offset = new_offset
387400
end
388401
return all_iden_ranges, all_ranges
389402
end
403+
function get_ranges_and_linked(varinfo::VarInfo{<:Metadata})
404+
all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1)
405+
return all_iden, all_others
406+
end
407+
function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
408+
all_iden_ranges = NamedTuple()
409+
all_ranges = Dict{VarName,RangeAndLinked}()
410+
offset = start_offset
411+
for (vn, idx) in md.idcs
412+
len = length(md.ranges[idx])
413+
is_linked = md.is_transformed[idx]
414+
range = offset:(offset + len - 1)
415+
if AbstractPPL.getoptic(vn) === identity
416+
all_iden_ranges = merge(
417+
all_iden_ranges,
418+
NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
419+
)
420+
else
421+
all_ranges[vn] = RangeAndLinked(range, is_linked)
422+
end
423+
offset += len
424+
end
425+
return all_iden_ranges, all_ranges, offset
426+
end
427+
function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
428+
all_iden_ranges = NamedTuple()
429+
all_ranges = Dict{VarName,RangeAndLinked}()
430+
offset = start_offset
431+
for (vn, idx) in vnv.varname_to_index
432+
len = length(vnv.ranges[idx])
433+
is_linked = vnv.is_unconstrained[idx]
434+
range = offset:(offset + len - 1)
435+
if AbstractPPL.getoptic(vn) === identity
436+
all_iden_ranges = merge(
437+
all_iden_ranges,
438+
NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
439+
)
440+
else
441+
all_ranges[vn] = RangeAndLinked(range, is_linked)
442+
end
443+
offset += len
444+
end
445+
return all_iden_ranges, all_ranges, offset
446+
end

0 commit comments

Comments
 (0)