@@ -39,7 +39,7 @@ julia> rng = StableRNG(42);
3939julia> # In the `NamedTuple` version we need to provide the place-holder values for
4040 # the variables which are using "containers", e.g. `Array`.
4141 # In this case, this means that we need to specify `x` but not `m`.
42- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo((x = ones(2), )));
42+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo((x = ones(2), )));
4343
4444julia> # (✓) Vroom, vroom! FAST!!!
4545 vi[@varname(x[1])]
@@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])]
5757 1.3736306979834252
5858
5959julia> # (×) If we don't provide the container...
60- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo()); vi
60+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo()); vi
6161ERROR: type NamedTuple has no field x
6262[...]
6363
6464julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
65+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
6666
6767julia> # (✓) Sort of fast, but only possible at runtime.
6868 vi[@varname(x[1])]
@@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods)
9191
9292julia> m = demo_constrained();
9393
94- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo());
94+ julia> _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo());
9595
9696julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
97971.8632965762164932
9898
99- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99+ julia> _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
100100
101101julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102102-0.21080155351918753
103103
104- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105105
106106julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107true
108108
109109julia> # And with `OrderedDict` of course!
110- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110+ _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112julia> vi[@varname(x)] # (✓) -∞ < x < ∞
1131130.6225185067787314
114114
115- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116116
117117julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118118true
@@ -232,24 +232,25 @@ end
232232
233233# Constructor from `Model`.
234234function SimpleVarInfo {T} (
235- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
235+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
236236) where {T<: Real }
237- new_model = contextualize (model, SamplingContext (rng, sampler, model. context))
237+ new_context = setleafcontext (model. context, InitContext (rng, init_strategy))
238+ new_model = contextualize (model, new_context)
238239 return last (evaluate!! (new_model, SimpleVarInfo {T} ()))
239240end
240241function SimpleVarInfo {T} (
241- model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
242+ model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
242243) where {T<: Real }
243- return SimpleVarInfo {T} (Random. default_rng (), model, sampler )
244+ return SimpleVarInfo {T} (Random. default_rng (), model, init_strategy )
244245end
245246# Constructors without type param
246247function SimpleVarInfo (
247- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
248+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
248249)
249- return SimpleVarInfo {LogProbType} (rng, model, sampler )
250+ return SimpleVarInfo {LogProbType} (rng, model, init_strategy )
250251end
251- function SimpleVarInfo (model:: Model , sampler :: AbstractSampler = SampleFromPrior ())
252- return SimpleVarInfo {LogProbType} (Random. default_rng (), model, sampler )
252+ function SimpleVarInfo (model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ())
253+ return SimpleVarInfo {LogProbType} (Random. default_rng (), model, init_strategy )
253254end
254255
255256# Constructor from `VarInfo`.
@@ -265,12 +266,12 @@ end
265266
266267function untyped_simple_varinfo (model:: Model )
267268 varinfo = SimpleVarInfo (OrderedDict {VarName,Any} ())
268- return last (evaluate_and_sample !! (model, varinfo))
269+ return last (init !! (model, varinfo))
269270end
270271
271272function typed_simple_varinfo (model:: Model )
272273 varinfo = SimpleVarInfo {Float64} ()
273- return last (evaluate_and_sample !! (model, varinfo))
274+ return last (init !! (model, varinfo))
274275end
275276
276277function unflatten (svi:: SimpleVarInfo , x:: AbstractVector )
@@ -480,7 +481,6 @@ function assume(
480481 return value, vi
481482end
482483
483- # NOTE: We don't implement `settrans!!(vi, trans, vn)`.
484484function settrans!! (vi:: SimpleVarInfo , trans)
485485 return settrans!! (vi, trans ? DynamicTransformation () : NoTransformation ())
486486end
490490function settrans!! (vi:: ThreadSafeVarInfo{<:SimpleVarInfo} , trans)
491491 return Accessors. @set vi. varinfo = settrans!! (vi. varinfo, trans)
492492end
493+ function settrans!! (vi:: SimpleOrThreadSafeSimple , trans:: Bool , :: VarName )
494+ # We keep this method around just to obey the AbstractVarInfo interface; however,
495+ # this is only a valid operation if it would be a no-op.
496+ if trans != istrans (vi)
497+ error (
498+ " Individual variables in SimpleVarInfo cannot have different `settrans` statuses." ,
499+ )
500+ end
501+ end
493502
494503istrans (vi:: SimpleVarInfo ) = ! (vi. transformation isa NoTransformation)
495504istrans (vi:: SimpleVarInfo , :: VarName ) = istrans (vi)
0 commit comments