From e5a038e805759b20a6e5c2ca501d1c545f3f5669 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Sat, 8 Nov 2025 17:38:42 +0000
Subject: [PATCH 1/2] Allow opting out of TSVI

---
 docs/src/api.md    |  6 ++++++
 src/DynamicPPL.jl  |  9 +++++++--
 src/fasteval.jl    |  5 +----
 src/model.jl       | 36 ++++++++++++++++++++++++------------
 test/threadsafe.jl | 22 ++++++++++++++++++++++
 5 files changed, 60 insertions(+), 18 deletions(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index e81f18dc7..d7238e8b8 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -110,6 +110,12 @@ Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the v
 DynamicPPL.unfix
 ```
 
+## Controlling threadsafe evaluation
+
+```@docs
+DynamicPPL.set_threadsafe_eval!
+```
+
 ## Predicting
 
 DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the `predict` function. This allows you to use posterior parameter samples to generate predictions for unobserved data points.
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index e9b902363..94a5288b0 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -92,6 +92,7 @@ export AbstractVarInfo,
     getargnames,
     extract_priors,
     values_as_in_model,
+    set_threadsafe_eval!,
     # LogDensityFunction
     LogDensityFunction,
     # Leaf contexts
@@ -212,8 +213,12 @@ include("test_utils.jl")
 include("experimental.jl")
 include("deprecated.jl")
 
-if isdefined(Base.Experimental, :register_error_hint)
-    function __init__()
+function __init__()
+    # This has to be set inside `__init__()`: at the top level it would be evaluated at
+    # precompile time, where `Threads.nthreads()` is 1, so it would always be `false`.
+    DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1)
+
+    if isdefined(Base.Experimental, :register_error_hint)
         # Better error message if users forget to load JET.jl
         Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
             requires_jet =
diff --git a/src/fasteval.jl b/src/fasteval.jl
index 5b9b767df..c976801e2 100644
--- a/src/fasteval.jl
+++ b/src/fasteval.jl
@@ -219,10 +219,7 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
     # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
     # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
     # here.
-    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
+    vi = if DynamicPPL.USE_THREADSAFE_EVAL[]
         accs = map(
             acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
             accs,
diff --git a/src/model.jl b/src/model.jl
index 2bcfe8f98..ce225c670 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,3 +1,25 @@
+# The value stored in this `Ref` is set in the `__init__()` function (src/DynamicPPL.jl).
+const USE_THREADSAFE_EVAL = Ref(true)
+
+"""
+    DynamicPPL.set_threadsafe_eval!(val::Bool)
+
+Enable or disable threadsafe model evaluation globally. By default, threadsafe evaluation is
+used whenever Julia is run with multiple threads.
+
+However, this is not necessary for the vast majority of DynamicPPL models. **In particular,
+use of threaded sampling with MCMCChains alone does NOT require threadsafe evaluation.**
+Threadsafe evaluation is only needed when manipulating `VarInfo` objects in parallel, e.g.
+when using `x ~ dist` statements inside `Threads.@threads` blocks.
+
+If you do not need threadsafe evaluation, disabling it can lead to significant performance
+improvements.
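+
+# Examples
+
+```julia
+# Disable threadsafe evaluation, e.g. if you never manipulate `VarInfo` objects
+# in parallel inside a model body.
+DynamicPPL.set_threadsafe_eval!(false)
+
+# Re-enable it before evaluating models that use `x ~ dist` inside
+# `Threads.@threads` blocks.
+DynamicPPL.set_threadsafe_eval!(true)
+```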
+""" +function set_threadsafe_eval!(val::Bool) + USE_THREADSAFE_EVAL[] = val + return nothing +end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} f::F @@ -863,16 +885,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -912,14 +924,14 @@ end Evaluate the `model` with the given `varinfo`. -If multiple threads are available, the varinfo provided will be wrapped in a +If threadsafe evaluation is enabled, the varinfo provided will be wrapped in a `ThreadSafeVarInfo` before evaluation. Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) + return if DynamicPPL.USE_THREADSAFE_EVAL[] evaluate_threadsafe!!(model, varinfo) else evaluate_threadunsafe!!(model, varinfo) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 522730566..f35b51d4f 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -1,4 +1,26 @@ @testset "threadsafe.jl" begin + @testset "set threadsafe eval" begin + # A dummy model that lets us see what type of VarInfo is being used for evaluation. + @model function find_out_varinfo_type() + x ~ Normal() + return typeof(__varinfo__) + end + model = find_out_varinfo_type() + + # Check the default. + @test DynamicPPL.USE_THREADSAFE_EVAL[] == (Threads.nthreads() > 1) + # Disable it. + DynamicPPL.set_threadsafe_eval!(false) + @test DynamicPPL.USE_THREADSAFE_EVAL[] == false + @test !(model() <: DynamicPPL.ThreadSafeVarInfo) + # Enable it. + DynamicPPL.set_threadsafe_eval!(true) + @test DynamicPPL.USE_THREADSAFE_EVAL[] == true + @test model() <: DynamicPPL.ThreadSafeVarInfo + # Reset to default to avoid messing with other tests. 
+ DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1) + end + @testset "constructor" begin vi = VarInfo(gdemo_default) threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) From 6a0ecfa8dea051870ee8fc7146c90ef4af077713 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 12:41:53 +0000 Subject: [PATCH 2/2] Replace use of `nthreads()` in tests --- test/compiler.jl | 2 +- test/fasteval.jl | 2 +- test/threadsafe.jl | 13 +++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..afe419b54 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,7 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 + if DynamicPPL.USE_THREADSAFE_EVAL[] @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @test retval.varinfo == svi else diff --git a/test/fasteval.jl b/test/fasteval.jl index f1c535643..0871e2dc0 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -96,7 +96,7 @@ end end @testset "FastLDF: performance" begin - if Threads.nthreads() == 1 + if !(DynamicPPL.USE_THREADSAFE_EVAL[]) # Evaluating these three models should not lead to any allocations (but only when # not using TSVI). @model function f() diff --git a/test/threadsafe.jl b/test/threadsafe.jl index f35b51d4f..d54a5afa9 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -64,6 +64,7 @@ @testset "model" begin println("Peforming threading tests with $(Threads.nthreads()) threads") + @show DynamicPPL.USE_THREADSAFE_EVAL[] x = rand(10_000) @@ -79,10 +80,10 @@ vi = VarInfo() model(vi) lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else + if DynamicPPL.USE_THREADSAFE_EVAL[] @test vi_ isa DynamicPPL.ThreadSafeVarInfo + else + @test vi_ isa VarInfo end println("With `@threads`:") @@ -112,10 +113,10 @@ vi = VarInfo() model(vi) lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else + if DynamicPPL.USE_THREADSAFE_EVAL[] @test vi_ isa DynamicPPL.ThreadSafeVarInfo + else + @test vi_ isa VarInfo end println("Without `@threads`:")
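
For reference, here is a minimal sketch of how the new switch is used from user code. It is adapted from the `find_out_varinfo_type` test above (the model name here is arbitrary) and assumes DynamicPPL and Distributions are loaded in a Julia session started with more than one thread:

```julia
using DynamicPPL, Distributions

# A model that reports which VarInfo type was used for its evaluation,
# following the pattern in test/threadsafe.jl.
@model function report_varinfo_type()
    x ~ Normal()
    return typeof(__varinfo__)
end
model = report_varinfo_type()

# With multiple threads, threadsafe evaluation is the default.
@assert model() <: DynamicPPL.ThreadSafeVarInfo

# Opt out: evaluation now uses a plain VarInfo and skips the ThreadSafeVarInfo wrapper.
DynamicPPL.set_threadsafe_eval!(false)
@assert !(model() <: DynamicPPL.ThreadSafeVarInfo)

# Opt back in before evaluating models that use `x ~ dist` inside `Threads.@threads` blocks.
DynamicPPL.set_threadsafe_eval!(true)
```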