@@ -80,7 +80,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
8080DynamicPPL. maybe_invlink_before_eval!! (vi:: OnlyAccsVarInfo , :: Model ) = vi
8181DynamicPPL. getaccs (vi:: OnlyAccsVarInfo ) = vi. accs
8282DynamicPPL. setaccs!! (:: OnlyAccsVarInfo , accs:: AccumulatorTuple ) = OnlyAccsVarInfo (accs)
83- function DynamicPPL. get_param_eltype (:: OnlyAccsVarInfo , model:: Model )
83+ function DynamicPPL. get_param_eltype (
84+ :: Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}} , model:: Model
85+ )
8486 # Because the VarInfo has no parameters stored in it, we need to get the eltype from the
8587 # model's leaf context. This is only possible if said leaf context is indeed a FastEval
8688 # context.
@@ -333,15 +335,16 @@ end
333335function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
334336 ctx = FastEvalVectorContext (f. _iden_varname_ranges, f. _varname_ranges, params)
335337 model = DynamicPPL. setleafcontext (f. _model, ctx)
336- only_accs_vi = OnlyAccsVarInfo ( fast_ldf_accs (f. _getlogdensity) )
338+ accs = fast_ldf_accs (f. _getlogdensity)
337339 # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
338340 # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
339341 # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
340342 # here.
341343 vi = if Threads. nthreads () > 1
342- ThreadSafeVarInfo (only_accs_vi)
344+ accs = map (acc -> convert_eltype (float_type_with_fallback (eltype (params)), acc), accs)
345+ ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
343346 else
344- only_accs_vi
347+ OnlyAccsVarInfo (accs)
345348 end
346349 _, vi = _evaluate!! (model, vi)
347350 return f. _getlogdensity (vi)
0 commit comments