850850# ^ Weird Documenter.jl bug means that we have to write the two above separately
851851# as it can only detect the `function`-less syntax.
852852function (model:: Model )(rng:: Random.AbstractRNG , varinfo:: AbstractVarInfo = VarInfo ())
853- return first (evaluate_and_sample !! (rng, model, varinfo))
853+ return first (init !! (rng, model, varinfo))
854854end
855855
856856"""
@@ -864,29 +864,35 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
864864end
865865
866866"""
867- evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])
868-
869- Evaluate the `model` with the given `varinfo`, but perform sampling during the
870- evaluation using the given `sampler` by wrapping the model's context in a
871- `SamplingContext`.
867+ init!!(
868+ [rng::Random.AbstractRNG, ]
869+ model::Model,
870+ varinfo::AbstractVarInfo,
871+ [init_strategy::AbstractInitStrategy=PriorInit()]
872+ )
872873
873- If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).
874+ Evaluate the `model` and replace the values of the model's random variables
875+ in the given `varinfo` with new values, using a specified initialisation strategy.
876+ If the values in `varinfo` are not set, they will be added.
877+ using a specified initialisation strategy. If `init_strategy` is not provided,
878+ defaults to PriorInit().
874879
875880Returns a tuple of the model's return value, plus the updated `varinfo` object.
876881"""
877- function evaluate_and_sample !! (
882+ function init !! (
878883 rng:: Random.AbstractRNG ,
879884 model:: Model ,
880885 varinfo:: AbstractVarInfo ,
881- sampler :: AbstractSampler = SampleFromPrior (),
886+ init_strategy :: AbstractInitStrategy = PriorInit (),
882887)
883- sampling_model = contextualize (model, SamplingContext (rng, sampler, model. context))
884- return evaluate!! (sampling_model, varinfo)
888+ new_context = setleafcontext (model. context, InitContext (rng, init_strategy))
889+ new_model = contextualize (model, new_context)
890+ return evaluate!! (new_model, varinfo)
885891end
886- function evaluate_and_sample !! (
887- model:: Model , varinfo:: AbstractVarInfo , sampler :: AbstractSampler = SampleFromPrior ()
892+ function init !! (
893+ model:: Model , varinfo:: AbstractVarInfo , init_strategy :: AbstractInitStrategy = PriorInit ()
888894)
889- return evaluate_and_sample !! (Random. default_rng (), model, varinfo, sampler )
895+ return init !! (Random. default_rng (), model, varinfo, init_strategy )
890896end
891897
892898"""
@@ -1049,11 +1055,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10491055Generate a sample of type `T` from the prior distribution of the `model`.
10501056"""
10511057function Base. rand (rng:: Random.AbstractRNG , :: Type{T} , model:: Model ) where {T}
1052- x = last (
1053- evaluate_and_sample!! (
1054- rng, model, SimpleVarInfo {Float64} (OrderedDict {VarName,Any} ())
1055- ),
1056- )
1058+ x = last (init!! (rng, model, SimpleVarInfo {Float64} (OrderedDict {VarName,Any} ())))
10571059 return values_as (x, T)
10581060end
10591061
@@ -1280,3 +1282,38 @@ end
12801282function returned (model:: Model , values, keys)
12811283 return returned (model, NamedTuple {keys} (values))
12821284end
1285+
1286+ """
1287+ prefix(model::Model, x::VarName)
1288+ prefix(model::Model, x::Val{sym})
1289+ prefix(model::Model, x::Any)
1290+
1291+ Return `model` but with all random variables prefixed by `x`, where `x` is either:
1292+ - a `VarName` (e.g. `@varname(a)`),
1293+ - a `Val{sym}` (e.g. `Val(:a)`), or
1294+ - for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that
1295+ this will introduce runtime overheads so is not recommended unless absolutely
1296+ necessary.
1297+
1298+ # Examples
1299+
1300+ ```jldoctest
1301+ julia> using DynamicPPL: prefix
1302+
1303+ julia> @model demo() = x ~ Dirac(1)
1304+ demo (generic function with 2 methods)
1305+
1306+ julia> rand(prefix(demo(), @varname(my_prefix)))
1307+ (var"my_prefix.x" = 1,)
1308+
1309+ julia> rand(prefix(demo(), Val(:my_prefix)))
1310+ (var"my_prefix.x" = 1,)
1311+ ```
1312+ """
1313+ prefix (model:: Model , x:: VarName ) = contextualize (model, PrefixContext (x, model. context))
1314+ function prefix (model:: Model , x:: Val{sym} ) where {sym}
1315+ return contextualize (model, PrefixContext (VarName {sym} (), model. context))
1316+ end
1317+ function prefix (model:: Model , x)
1318+ return contextualize (model, PrefixContext (VarName {Symbol(x)} (), model. context))
1319+ end
0 commit comments