diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..418d42756 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -110,6 +110,12 @@ Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the v DynamicPPL.unfix ``` +## Controlling threadsafe evaluation + +```@docs +DynamicPPL.set_threadsafe_eval! +``` + ## Predicting DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the `predict` function. This allows you to use posterior parameter samples to generate predictions for unobserved data points. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..f7f3c8bda 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,6 +92,7 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, + set_threadsafe_eval!, # LogDensityFunction LogDensityFunction, # Contexts @@ -204,8 +205,12 @@ include("test_utils.jl") include("experimental.jl") include("deprecated.jl") -if isdefined(Base.Experimental, :register_error_hint) - function __init__() +function __init__() + # This has to be in the `__init__()` function, if it's placed at the top level it + # always evaluates to false. 
+ DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1) + + if isdefined(Base.Experimental, :register_error_hint) # Better error message if users forget to load JET.jl Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ requires_jet = diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..efc6f1087 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..6b3e84a4b 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,10 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -81,6 +85,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( 
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} + return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} + # No need to enforce any particular eltype here, since new parameters are sampled + return Any else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) @@ -366,10 +372,7 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 + vi = if DynamicPPL.USE_THREADSAFE_EVAL[] accs = map( acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), accs, diff --git a/src/model.jl b/src/model.jl index 6ca06aea6..3a29702ac 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,25 @@ +# This is overridden in the `__init__()` function (src/DynamicPPL.jl) +const USE_THREADSAFE_EVAL = Ref(true) + +""" + DynamicPPL.set_threadsafe_eval!(val::Bool) + +Enable or disable threadsafe model evaluation globally. By default, threadsafe evaluation is +used whenever Julia is run with multiple threads. + +However, this is not necessary for the vast majority of DynamicPPL models. **In particular, +use of threaded sampling with MCMCChains alone does NOT require threadsafe evaluation.** +Threadsafe evaluation is only needed when manipulating `VarInfo` objects in parallel, e.g. +when using `x ~ dist` statements inside `Threads.@threads` blocks. + +If you do not need threadsafe evaluation, disabling it can lead to significant performance +improvements. +""" +function set_threadsafe_eval!(val::Bool) + USE_THREADSAFE_EVAL[] = val + return nothing +end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} f::F @@ -863,16 +885,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -912,14 +924,14 @@ end Evaluate the `model` with the given `varinfo`.
-If multiple threads are available, the varinfo provided will be wrapped in a +If threadsafe evaluation is enabled, the varinfo provided will be wrapped in a `ThreadSafeVarInfo` before evaluation. Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) + return if DynamicPPL.USE_THREADSAFE_EVAL[] evaluate_threadsafe!!(model, varinfo) else evaluate_threadunsafe!!(model, varinfo) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 522730566..f35b51d4f 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -1,4 +1,26 @@ @testset "threadsafe.jl" begin + @testset "set threadsafe eval" begin + # A dummy model that lets us see what type of VarInfo is being used for evaluation. + @model function find_out_varinfo_type() + x ~ Normal() + return typeof(__varinfo__) + end + model = find_out_varinfo_type() + + # Check the default. + @test DynamicPPL.USE_THREADSAFE_EVAL[] == (Threads.nthreads() > 1) + # Disable it. + DynamicPPL.set_threadsafe_eval!(false) + @test DynamicPPL.USE_THREADSAFE_EVAL[] == false + @test !(model() <: DynamicPPL.ThreadSafeVarInfo) + # Enable it. + DynamicPPL.set_threadsafe_eval!(true) + @test DynamicPPL.USE_THREADSAFE_EVAL[] == true + @test model() <: DynamicPPL.ThreadSafeVarInfo + # Reset to default to avoid messing with other tests. + DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1) + end + @testset "constructor" begin vi = VarInfo(gdemo_default) threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)