11# assume
2- """
3- tilde_assume(context::SamplingContext, right, vn, vi)
4-
5- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6- accumulate the log probability, and return the sampled value with a context associated
7- with a sampler.
8-
9- Falls back to
10- ```julia
11- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12- ```
13- """
14- function tilde_assume (context:: SamplingContext , right, vn, vi)
15- return tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16- end
17-
182function tilde_assume (context:: AbstractContext , args... )
193 return tilde_assume (childcontext (context), args... )
204end
215function tilde_assume (:: DefaultContext , right, vn, vi)
22- return assume (right, vn, vi)
23- end
24-
25- function tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26- return tilde_assume (rng, childcontext (context), args... )
27- end
28- function tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29- return assume (rng, sampler, right, vn, vi)
30- end
31- function tilde_assume (:: DefaultContext , sampler, right, vn, vi)
32- # same as above but no rng
33- return assume (Random. default_rng (), sampler, right, vn, vi)
6+ y = getindex_internal (vi, vn)
7+ f = from_maybe_linked_internal_transform (vi, vn, right)
8+ x, inv_logjac = with_logabsdet_jacobian (f, y)
9+ vi = accumulate_assume!! (vi, x, - inv_logjac, vn, right)
10+ return x, vi
3411end
35-
3612function tilde_assume (context:: PrefixContext , right, vn, vi)
3713 # Note that we can't use something like this here:
3814 # new_vn = prefix(context, vn)
@@ -46,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
4622 new_vn, new_context = prefix_and_strip_contexts (context, vn)
4723 return tilde_assume (new_context, right, new_vn, vi)
4824end
49- function tilde_assume (
50- rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
51- )
52- new_vn, new_context = prefix_and_strip_contexts (context, vn)
53- return tilde_assume (rng, new_context, sampler, right, new_vn, vi)
54- end
5525
5626"""
5727 tilde_assume!!(context, right, vn, vi)
@@ -71,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7141end
7242
7343# observe
74- """
75- tilde_observe!!(context::SamplingContext, right, left, vi)
76-
77- Handle observed constants with a `context` associated with a sampler.
78-
79- Falls back to `tilde_observe!!(context.context, right, left, vi)`.
80- """
81- function tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
82- return tilde_observe!! (context. context, right, left, vn, vi)
83- end
84-
8544function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
8645 return tilde_observe!! (childcontext (context), right, left, vn, vi)
8746end
@@ -114,58 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
11473 vi = accumulate_observe!! (vi, right, left, vn)
11574 return left, vi
11675end
117-
118- function assume (:: Random.AbstractRNG , spl:: Sampler , dist)
119- return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
120- end
121-
122- # fallback without sampler
123- function assume (dist:: Distribution , vn:: VarName , vi)
124- y = getindex_internal (vi, vn)
125- f = from_maybe_linked_internal_transform (vi, vn, dist)
126- x, inv_logjac = with_logabsdet_jacobian (f, y)
127- vi = accumulate_assume!! (vi, x, - inv_logjac, vn, dist)
128- return x, vi
129- end
130-
131- # TODO : Remove this thing.
132- # SampleFromPrior and SampleFromUniform
133- function assume (
134- rng:: Random.AbstractRNG ,
135- sampler:: Union{SampleFromPrior,SampleFromUniform} ,
136- dist:: Distribution ,
137- vn:: VarName ,
138- vi:: VarInfoOrThreadSafeVarInfo ,
139- )
140- if haskey (vi, vn)
141- # Always overwrite the parameters with new ones for `SampleFromUniform`.
142- if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
143- # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
144- # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
145- # if that's okay.
146- unset_flag! (vi, vn, " del" , true )
147- r = init (rng, dist, sampler)
148- f = to_maybe_linked_internal_transform (vi, vn, dist)
149- # TODO (mhauru) This should probably be call a function called setindex_internal!
150- vi = BangBang. setindex!! (vi, f (r), vn)
151- else
152- # Otherwise we just extract it.
153- r = vi[vn, dist]
154- end
155- else
156- r = init (rng, dist, sampler)
157- if istrans (vi)
158- f = to_linked_internal_transform (vi, vn, dist)
159- vi = push!! (vi, vn, f (r), dist)
160- # By default `push!!` sets the transformed flag to `false`.
161- vi = settrans!! (vi, true , vn)
162- else
163- vi = push!! (vi, vn, r, dist)
164- end
165- end
166-
167- # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
168- logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
169- vi = accumulate_assume!! (vi, r, logjac, vn, dist)
170- return r, vi
171- end
0 commit comments