3838"""
3939 tilde_assume(ctx, sampler, right, vn, inds, vi)
4040
41- This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where
42- `x` does not occur in the model inputs .
41+ Handle assumed variables, e.g., `x ~ Normal()` ( where `x` does occur in the model inputs),
42+ accumulate the log probability, and return the sampled value .
4343
4444Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4545"""
4646function tilde_assume (ctx, sampler, right, vn, inds, vi)
47- return tilde (ctx, sampler, right, vn, inds, vi)
47+ value, logp = tilde (ctx, sampler, right, vn, inds, vi)
48+ acclogp! (vi, logp)
49+ return value
4850end
4951
5052
7274"""
7375 tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7476
75- This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where
76- `x` does occur in the model inputs .
77+ Handle observed variables, e.g., `x ~ Normal()` ( where `x` does occur in the model inputs),
78+ accumulate the log probability, and return the observed value .
7779
78- Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
79- name and indices; if needed, these can be accessed through this function, though.
80+ Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
81+ and indices; if needed, these can be accessed through this function, though.
8082"""
8183function tilde_observe (ctx, sampler, right, left, vname, vinds, vi)
82- return tilde (ctx, sampler, right, left, vi)
84+ logp = tilde (ctx, sampler, right, left, vi)
85+ acclogp! (vi, logp)
86+ return left
8387end
8488
8589"""
8690 tilde_observe(ctx, sampler, right, left, vi)
8791
88- This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`.
92+ Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
93+ observed value.
94+
8995Falls back to `tilde(ctx, sampler, right, left, vi)`.
9096"""
9197function tilde_observe (ctx, sampler, right, left, vi)
92- return tilde (ctx, sampler, right, left, vi)
98+ logp = tilde (ctx, sampler, right, left, vi)
99+ acclogp! (vi, logp)
100+ return left
93101end
94102
95103
@@ -103,24 +111,48 @@ function observe(spl::Sampler, weight)
103111 error (" DynamicPPL.observe: unmanaged inference algorithm: $(typeof (spl)) " )
104112end
105113
114+ # If parameters exist, they are used and not overwritten.
106115function assume (
107- spl:: Union{ SampleFromPrior, SampleFromUniform} ,
116+ spl:: SampleFromPrior ,
108117 dist:: Distribution ,
109118 vn:: VarName ,
110119 vi:: VarInfo ,
111120)
112121 if haskey (vi, vn)
113122 if is_flagged (vi, vn, " del" )
114123 unset_flag! (vi, vn, " del" )
115- r = spl isa SampleFromUniform ? init (dist) : rand (dist)
124+ r = rand (dist)
116125 vi[vn] = vectorize (dist, r)
126+ settrans! (vi, false , vn)
117127 setorder! (vi, vn, get_num_produce (vi))
118128 else
119- r = vi[vn]
129+ r = vi[vn]
120130 end
121131 else
122- r = isa (spl, SampleFromUniform) ? init (dist) : rand (dist)
132+ r = rand (dist)
123133 push! (vi, vn, r, dist, spl)
134+ settrans! (vi, false , vn)
135+ end
136+ return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn))
137+ end
138+
139+ # Always overwrites the parameters with new ones.
140+ function assume (
141+ spl:: SampleFromUniform ,
142+ dist:: Distribution ,
143+ vn:: VarName ,
144+ vi:: VarInfo ,
145+ )
146+ if haskey (vi, vn)
147+ unset_flag! (vi, vn, " del" )
148+ r = init (dist)
149+ vi[vn] = vectorize (dist, r)
150+ settrans! (vi, true , vn)
151+ setorder! (vi, vn, get_num_produce (vi))
152+ else
153+ r = init (dist)
154+ push! (vi, vn, r, dist, spl)
155+ settrans! (vi, true , vn)
124156 end
125157 # NOTE: The importance weight is not correctly computed here because
126158 # r is genereated from some uniform distribution which is different from the prior
@@ -191,13 +223,15 @@ end
191223"""
192224 dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
193225
194- This method is applied in the generated code for assumed vectorized variables, e.g., `x .~
195- MvNormal()` where `x` does not occur in the model inputs .
226+ Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
227+ model inputs), accumulate the log probability, and return the sampled value .
196228
197229Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
198230"""
199231function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi)
200- return dot_tilde (ctx, sampler, right, left, vn, inds, vi)
232+ value, logp = dot_tilde (ctx, sampler, right, left, vn, inds, vi)
233+ acclogp! (vi, logp)
234+ return value
201235end
202236
203237
@@ -367,24 +401,30 @@ end
367401"""
368402 dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
369403
370- This method is applied in the generated code for vectorized observed variables , e.g., `x .~
371- MvNormal()` where `x` does occur the model inputs .
404+ Handle broadcasted observed values , e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
405+ accumulate the log probability, and return the observed value .
372406
373407Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
374408name and indices; if needed, these can be accessed through this function, though.
375409"""
376410function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi)
377- return dot_tilde (ctx, sampler, right, left, vi)
411+ logp = dot_tilde (ctx, sampler, right, left, vi)
412+ acclogp! (vi, logp)
413+ return left
378414end
379415
380416"""
381417 dot_tilde_observe(ctx, sampler, right, left, vi)
382418
383- This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~
384- MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
419+ Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
420+ probability, and return the observed value.
421+
422+ Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
385423"""
386424function dot_tilde_observe (ctx, sampler, right, left, vi)
387- return dot_tilde (ctx, sampler, right, left, vi)
425+ logp = dot_tilde (ctx, sampler, right, left, vi)
426+ acclogp! (vi, logp)
427+ return left
388428end
389429
390430
0 commit comments