Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the v
DynamicPPL.unfix
```

## Controlling threadsafe evaluation

```@docs
DynamicPPL.set_threadsafe_eval!
```

## Predicting

DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the `predict` function. This allows you to use posterior parameter samples to generate predictions for unobserved data points.
Expand Down
9 changes: 7 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export AbstractVarInfo,
getargnames,
extract_priors,
values_as_in_model,
set_threadsafe_eval!,
# LogDensityFunction
LogDensityFunction,
# Contexts
Expand Down Expand Up @@ -204,8 +205,12 @@ include("test_utils.jl")
include("experimental.jl")
include("deprecated.jl")

if isdefined(Base.Experimental, :register_error_hint)
function __init__()
function __init__()
# This has to be in the `__init__()` function, if it's placed at the top level it
# always evaluates to false.
DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1)

if isdefined(Base.Experimental, :register_error_hint)
# Better error message if users forget to load JET.jl
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_jet =
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
# Convenience constructor: accept a `NamedTuple` of parameter values and store it
# as-is (note: no conversion to a VarName-keyed dict here — the NamedTuple itself
# becomes the `P` type parameter). The `fallback` strategy is consulted for
# variables not present in `params`; it defaults to sampling from the prior.
function InitFromParams(
    params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
)
    return new{typeof(params),typeof(fallback)}(params, fallback)
end
end
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
Expand Down
25 changes: 14 additions & 11 deletions src/fasteval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ using DynamicPPL:
AbstractContext,
AbstractVarInfo,
AccumulatorTuple,
InitContext,
InitFromParams,
InitFromPrior,
InitFromUniform,
LogJacobianAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
Expand All @@ -81,6 +85,7 @@ using DynamicPPL:
getlogprior_internal,
leafcontext
using ADTypes: ADTypes
using BangBang: BangBang
using Bijectors: with_logabsdet_jacobian
using AbstractPPL: AbstractPPL, VarName
using Distributions: Distribution
Expand Down Expand Up @@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
# Interface methods for `OnlyAccsVarInfo`, which carries only accumulators:
# no parameter values are ever retained (`push!!` is a no-op and `haskey`
# always misses), so most of the generic varinfo machinery degenerates to
# trivial answers here.

# No stored values means there is nothing to invlink before evaluation.
function DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model)
    return vi
end

# The accumulator tuple is the entirety of this varinfo's state.
function DynamicPPL.getaccs(vi::OnlyAccsVarInfo)
    return vi.accs
end

# Immutable struct: swapping accumulators builds a fresh wrapper.
function DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple)
    return OnlyAccsVarInfo(accs)
end

# Variable lookups always miss, since no variables are stored.
@inline function Base.haskey(::OnlyAccsVarInfo, ::VarName)
    return false
end

# With no stored values there is nothing to be link-transformed.
@inline function DynamicPPL.is_transformed(::OnlyAccsVarInfo)
    return false
end

# Pushing a new variable is deliberately a no-op: values are discarded and
# the varinfo is returned unchanged.
@inline function BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist)
    return vi
end
function DynamicPPL.get_param_eltype(
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
)
Expand All @@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype(
leaf_ctx = DynamicPPL.leafcontext(model.context)
if leaf_ctx isa FastEvalVectorContext
return eltype(leaf_ctx.params)
elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams}
return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params)
elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}}
# No need to enforce any particular eltype here, since new parameters are sampled
return Any
else
# TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}.
# See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
# figure out the parameter type from a NamedTuple or Dict. The benefit of
# implementing this for InitContext is that we could then use OnlyAccsVarInfo with
# it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
# that Mooncake / Enzyme should be able to differentiate through that too and
# provide a NamedTuple of gradients (although I haven't tested this yet).
error(
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
)
Expand Down Expand Up @@ -366,10 +372,7 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
# here.
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
# it _should_ do, but this is wrong regardless.
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
vi = if DynamicPPL.USE_THREADSAFE_EVAL[]
accs = map(
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
accs,
Expand Down
36 changes: 24 additions & 12 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
# Global flag controlling whether model evaluation wraps the varinfo in a
# `ThreadSafeVarInfo`. The stored value is overwritten in the `__init__()`
# function (src/DynamicPPL.jl) based on `Threads.nthreads()`; it cannot be
# computed here at top level because that code runs during precompilation,
# where only a single thread is available, so the check would always be false.
# Declared `const` so reads of `USE_THREADSAFE_EVAL[]` on the evaluation hot
# path go through a typed binding rather than an untyped global.
const USE_THREADSAFE_EVAL = Ref(true)

"""
    DynamicPPL.set_threadsafe_eval!(val::Bool)

Enable or disable threadsafe model evaluation globally. By default, threadsafe evaluation is
used whenever Julia is run with multiple threads.

However, this is not necessary for the vast majority of DynamicPPL models. **In particular,
use of threaded sampling with MCMCChains alone does NOT require threadsafe evaluation.**
Threadsafe evaluation is only needed when manipulating `VarInfo` objects in parallel, e.g.
when using `x ~ dist` statements inside `Threads.@threads` blocks.

If you do not need threadsafe evaluation, disabling it can lead to significant performance
improvements.
"""
function set_threadsafe_eval!(val::Bool)
    USE_THREADSAFE_EVAL[] = val
    return nothing
end

"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
f::F
Expand Down Expand Up @@ -863,16 +885,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
return first(init!!(rng, model, varinfo))
end

"""
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)

Return `true` if evaluation of a model using `context` and `varinfo` should
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
"""
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
return Threads.nthreads() > 1
end
Comment on lines -866 to -874
Copy link
Member Author

@penelopeysm penelopeysm Nov 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, Turing overloads this function, so releasing this in a patch would cause upstream breakage. I don't feel bad about this though because this function is not exported.

On top of that, the method that Turing defines for this is bogus. It silently disables TSVI for PG/SMC, because Libtask can't handle it, and it generally doesn't work for SMC anyway because observations for all particles need to be in step. The problem is that for models that do need threadsafe eval, this will lead to silent wrong results (TuringLang/Turing.jl#2658).

The correct solution for this should be to emit an error saying:

PG/SMC can't run with models that require threadsafe eval. If you know your model doesn't need it, then you should call DynamicPPL.set_threadsafe_eval!(false).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that FastLDF is going into breaking instead, we should just rebase this on breaking.


"""
init!!(
[rng::Random.AbstractRNG,]
Expand Down Expand Up @@ -912,14 +924,14 @@ end

Evaluate the `model` with the given `varinfo`.

If multiple threads are available, the varinfo provided will be wrapped in a
If threadsafe evaluation is enabled, the varinfo provided will be wrapped in a
`ThreadSafeVarInfo` before evaluation.

Returns a tuple of the model's return value, plus the updated `varinfo`
(unwrapped if necessary).
"""
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
return if use_threadsafe_eval(model.context, varinfo)
return if DynamicPPL.USE_THREADSAFE_EVAL[]
evaluate_threadsafe!!(model, varinfo)
else
evaluate_threadunsafe!!(model, varinfo)
Expand Down
22 changes: 22 additions & 0 deletions test/threadsafe.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
@testset "threadsafe.jl" begin
@testset "set threadsafe eval" begin
# A dummy model that lets us see what type of VarInfo is being used for evaluation.
@model function find_out_varinfo_type()
x ~ Normal()
return typeof(__varinfo__)
end
# Because the model's return value is `typeof(__varinfo__)`, calling `model()`
# yields the varinfo *type* used during evaluation, which we assert on below.
model = find_out_varinfo_type()

# Check the default.
@test DynamicPPL.USE_THREADSAFE_EVAL[] == (Threads.nthreads() > 1)
# Disable it.
DynamicPPL.set_threadsafe_eval!(false)
@test DynamicPPL.USE_THREADSAFE_EVAL[] == false
@test !(model() <: DynamicPPL.ThreadSafeVarInfo)
# Enable it.
DynamicPPL.set_threadsafe_eval!(true)
@test DynamicPPL.USE_THREADSAFE_EVAL[] == true
@test model() <: DynamicPPL.ThreadSafeVarInfo
# Reset to default to avoid messing with other tests.
DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1)
end

@testset "constructor" begin
vi = VarInfo(gdemo_default)
threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)
Expand Down