Skip to content

Commit 00886ec

Browse files
committed
Allow opting out of TSVI
1 parent 7deaaab commit 00886ec

File tree

4 files changed

+52
-12
lines changed

4 files changed

+52
-12
lines changed

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the v
110110
DynamicPPL.unfix
111111
```
112112

113+
## Controlling threadsafe evaluation
114+
115+
```@docs
116+
DynamicPPL.set_threadsafe_eval!
117+
```
118+
113119
## Predicting
114120

115121
DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the `predict` function. This allows you to use posterior parameter samples to generate predictions for unobserved data points.

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ export AbstractVarInfo,
9292
getargnames,
9393
extract_priors,
9494
values_as_in_model,
95+
set_threadsafe_eval!,
9596
# LogDensityFunction
9697
LogDensityFunction,
9798
# Contexts

src/model.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
const USE_THREADSAFE_EVAL = Ref(Threads.nthreads() > 1)
2+
3+
"""
4+
DynamicPPL.set_threadsafe_eval!(val::Bool)
5+
6+
Enable or disable threadsafe model evaluation globally. By default, threadsafe evaluation is
7+
used whenever Julia is run with multiple threads.
8+
9+
However, this is not necessary for the vast majority of DynamicPPL models. **In particular,
10+
use of threaded sampling with MCMCChains alone does NOT require threadsafe evaluation.**
11+
Threadsafe evaluation is only needed when manipulating `VarInfo` objects in parallel, e.g.
12+
when using `x ~ dist` statements inside `Threads.@threads` blocks.
13+
14+
If you do not need threadsafe evaluation, disabling it can lead to significant performance
15+
improvements.
16+
"""
17+
function set_threadsafe_eval!(val::Bool)
18+
USE_THREADSAFE_EVAL[] = val
19+
return nothing
20+
end
21+
122
"""
223
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
324
f::F
@@ -863,16 +884,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
863884
return first(init!!(rng, model, varinfo))
864885
end
865886

866-
"""
867-
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
868-
869-
Return `true` if evaluation of a model using `context` and `varinfo` should
870-
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
871-
"""
872-
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
873-
return Threads.nthreads() > 1
874-
end
875-
876887
"""
877888
init!!(
878889
[rng::Random.AbstractRNG,]
@@ -912,14 +923,14 @@ end
912923
913924
Evaluate the `model` with the given `varinfo`.
914925
915-
If multiple threads are available, the varinfo provided will be wrapped in a
926+
If threadsafe evaluation is enabled, the varinfo provided will be wrapped in a
916927
`ThreadSafeVarInfo` before evaluation.
917928
918929
Returns a tuple of the model's return value, plus the updated `varinfo`
919930
(unwrapped if necessary).
920931
"""
921932
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
922-
return if use_threadsafe_eval(model.context, varinfo)
933+
return if DynamicPPL.USE_THREADSAFE_EVAL[]
923934
evaluate_threadsafe!!(model, varinfo)
924935
else
925936
evaluate_threadunsafe!!(model, varinfo)

test/threadsafe.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,26 @@
11
@testset "threadsafe.jl" begin
2+
@testset "set threadsafe eval" begin
3+
# A dummy model that lets us see what type of VarInfo is being used for evaluation.
4+
@model function find_out_varinfo_type()
5+
x ~ Normal()
6+
return typeof(__varinfo__)
7+
end
8+
model = find_out_varinfo_type()
9+
10+
# Check the default.
11+
@test DynamicPPL.USE_THREADSAFE_EVAL[] == (Threads.nthreads() > 1)
12+
# Disable it.
13+
DynamicPPL.set_threadsafe_eval!(false)
14+
@test DynamicPPL.USE_THREADSAFE_EVAL[] == false
15+
@test !(model() <: DynamicPPL.ThreadSafeVarInfo)
16+
# Enable it.
17+
DynamicPPL.set_threadsafe_eval!(true)
18+
@test DynamicPPL.USE_THREADSAFE_EVAL[] == true
19+
@test model() <: DynamicPPL.ThreadSafeVarInfo
20+
# Reset to default to avoid messing with other tests.
21+
DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1)
22+
end
23+
224
@testset "constructor" begin
325
vi = VarInfo(gdemo_default)
426
threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)

0 commit comments

Comments
 (0)