Skip to content

Commit d61f19a

Browse files
committed
Fix get_param_eltype for TSVI
1 parent ec45c33 commit d61f19a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/fastldf.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
8080
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
8181
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
8282
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
83-
function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model)
83+
function DynamicPPL.get_param_eltype(
84+
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
85+
)
8486
# Because the VarInfo has no parameters stored in it, we need to get the eltype from the
8587
# model's leaf context. This is only possible if said leaf context is indeed a FastEval
8688
# context.
@@ -333,15 +335,16 @@ end
333335
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
334336
ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
335337
model = DynamicPPL.setleafcontext(f._model, ctx)
336-
only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))
338+
accs = fast_ldf_accs(f._getlogdensity)
337339
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
338340
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
339341
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
340342
# here.
341343
vi = if Threads.nthreads() > 1
342-
ThreadSafeVarInfo(only_accs_vi)
344+
accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs)
345+
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
343346
else
344-
only_accs_vi
347+
OnlyAccsVarInfo(accs)
345348
end
346349
_, vi = _evaluate!!(model, vi)
347350
return f._getlogdensity(vi)

0 commit comments

Comments
 (0)