|
| 1 | +const USE_THREADSAFE_EVAL = Ref(Threads.nthreads() > 1) |
| 2 | + |
| 3 | +""" |
| 4 | + DynamicPPL.set_threadsafe_eval!(val::Bool) |
| 5 | +
|
| 6 | +Enable or disable threadsafe model evaluation globally. By default, threadsafe evaluation is |
| 7 | +used whenever Julia is run with multiple threads. |
| 8 | +
|
| 9 | +However, this is not necessary for the vast majority of DynamicPPL models. **In particular, |
| 10 | +use of threaded sampling with MCMCChains alone does NOT require threadsafe evaluation.** |
| 11 | +Threadsafe evaluation is only needed when manipulating `VarInfo` objects in parallel, e.g. |
| 12 | +when using `x ~ dist` statements inside `Threads.@threads` blocks. |
| 13 | +
|
| 14 | +If you do not need threadsafe evaluation, disabling it can lead to significant performance |
| 15 | +improvements. |
| 16 | +""" |
| 17 | +function set_threadsafe_eval!(val::Bool) |
| 18 | + USE_THREADSAFE_EVAL[] = val |
| 19 | + return nothing |
| 20 | +end |
| 21 | + |
1 | 22 | """ |
2 | 23 | struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} |
3 | 24 | f::F |
@@ -863,16 +884,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf |
863 | 884 | return first(init!!(rng, model, varinfo)) |
864 | 885 | end |
865 | 886 |
|
866 | | -""" |
867 | | - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) |
868 | | -
|
869 | | -Return `true` if evaluation of a model using `context` and `varinfo` should |
870 | | -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. |
871 | | -""" |
872 | | -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) |
873 | | - return Threads.nthreads() > 1 |
874 | | -end |
875 | | - |
876 | 887 | """ |
877 | 888 | init!!( |
878 | 889 | [rng::Random.AbstractRNG,] |
@@ -912,14 +923,14 @@ end |
912 | 923 |
|
913 | 924 | Evaluate the `model` with the given `varinfo`. |
914 | 925 |
|
915 | | -If multiple threads are available, the varinfo provided will be wrapped in a |
| 926 | +If threadsafe evaluation is enabled, the varinfo provided will be wrapped in a |
916 | 927 | `ThreadSafeVarInfo` before evaluation. |
917 | 928 |
|
918 | 929 | Returns a tuple of the model's return value, plus the updated `varinfo` |
919 | 930 | (unwrapped if necessary). |
920 | 931 | """ |
921 | 932 | function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) |
922 | | - return if use_threadsafe_eval(model.context, varinfo) |
| 933 | + return if DynamicPPL.USE_THREADSAFE_EVAL[] |
923 | 934 | evaluate_threadsafe!!(model, varinfo) |
924 | 935 | else |
925 | 936 | evaluate_threadunsafe!!(model, varinfo) |
|
0 commit comments