@@ -39,7 +39,7 @@ julia> rng = StableRNG(42);
3939julia> # In the `NamedTuple` version we need to provide the place-holder values for
4040 # the variables which are using "containers", e.g. `Array`.
4141 # In this case, this means that we need to specify `x` but not `m`.
42- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo((x = ones(2), )));
42+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo((x = ones(2), )));
4343
4444julia> # (✓) Vroom, vroom! FAST!!!
4545 vi[@varname(x[1])]
@@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])]
5757 1.3736306979834252
5858
5959julia> # (×) If we don't provide the container...
60- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo()); vi
60+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo()); vi
6161ERROR: type NamedTuple has no field x
6262[...]
6363
6464julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
65+ _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
6666
6767julia> # (✓) Sort of fast, but only possible at runtime.
6868 vi[@varname(x[1])]
@@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods)
9191
9292julia> m = demo_constrained();
9393
94- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, SimpleVarInfo());
94+ julia> _, vi = DynamicPPL.init !!(rng, m, SimpleVarInfo());
9595
9696julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
97971.8632965762164932
9898
99- julia> _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99+ julia> _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
100100
101101julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102102-0.21080155351918753
103103
104- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105105
106106julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107true
108108
109109julia> # And with `OrderedDict` of course!
110- _, vi = DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110+ _, vi = DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112julia> vi[@varname(x)] # (✓) -∞ < x < ∞
1131130.6225185067787314
114114
115- julia> xs = [last(DynamicPPL.evaluate_and_sample !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115+ julia> xs = [last(DynamicPPL.init !!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116116
117117julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118118true
@@ -226,24 +226,25 @@ end
226226
227227# Constructor from `Model`.
228228function SimpleVarInfo {T} (
229- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
229+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
230230) where {T<: Real }
231- new_model = contextualize (model, SamplingContext (rng, sampler, model. context))
231+ new_context = setleafcontext (model. context, InitContext (rng, init_strategy))
232+ new_model = contextualize (model, new_context)
232233 return last (evaluate!! (new_model, SimpleVarInfo {T} ()))
233234end
234235function SimpleVarInfo {T} (
235- model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
236+ model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
236237) where {T<: Real }
237- return SimpleVarInfo {T} (Random. default_rng (), model, sampler )
238+ return SimpleVarInfo {T} (Random. default_rng (), model, init_strategy )
238239end
239240# Constructors without type param
240241function SimpleVarInfo (
241- rng:: Random.AbstractRNG , model:: Model , sampler :: AbstractSampler = SampleFromPrior ()
242+ rng:: Random.AbstractRNG , model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ()
242243)
243- return SimpleVarInfo {LogProbType} (rng, model, sampler )
244+ return SimpleVarInfo {LogProbType} (rng, model, init_strategy )
244245end
245- function SimpleVarInfo (model:: Model , sampler :: AbstractSampler = SampleFromPrior ())
246- return SimpleVarInfo {LogProbType} (Random. default_rng (), model, sampler )
246+ function SimpleVarInfo (model:: Model , init_strategy :: AbstractInitStrategy = PriorInit ())
247+ return SimpleVarInfo {LogProbType} (Random. default_rng (), model, init_strategy )
247248end
248249
249250# Constructor from `VarInfo`.
@@ -259,12 +260,12 @@ end
259260
260261function untyped_simple_varinfo (model:: Model )
261262 varinfo = SimpleVarInfo (OrderedDict {VarName,Any} ())
262- return last (evaluate_and_sample !! (model, varinfo))
263+ return last (init !! (model, varinfo))
263264end
264265
265266function typed_simple_varinfo (model:: Model )
266267 varinfo = SimpleVarInfo {Float64} ()
267- return last (evaluate_and_sample !! (model, varinfo))
268+ return last (init !! (model, varinfo))
268269end
269270
270271function unflatten (svi:: SimpleVarInfo , x:: AbstractVector )
@@ -474,7 +475,6 @@ function assume(
474475 return value, vi
475476end
476477
477- # NOTE: We don't implement `settrans!!(vi, trans, vn)`.
478478function settrans!! (vi:: SimpleVarInfo , trans)
479479 return settrans!! (vi, trans ? DynamicTransformation () : NoTransformation ())
480480end
484484function settrans!! (vi:: ThreadSafeVarInfo{<:SimpleVarInfo} , trans)
485485 return Accessors. @set vi. varinfo = settrans!! (vi. varinfo, trans)
486486end
487+ function settrans!! (vi:: SimpleOrThreadSafeSimple , trans:: Bool , :: VarName )
488+ # We keep this method around just to obey the AbstractVarInfo interface; however,
489+ # this is only a valid operation if it would be a no-op.
490+ if trans != istrans (vi)
491+ error (
492+ " Individual variables in SimpleVarInfo cannot have different `settrans` statuses." ,
493+ )
494+ end
495+ end
487496
488497istrans (vi:: SimpleVarInfo ) = ! (vi. transformation isa NoTransformation)
489498istrans (vi:: SimpleVarInfo , :: VarName ) = istrans (vi)
0 commit comments