@@ -36,17 +36,16 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3636end
3737
3838"""
39- tilde_assume(ctx, sampler, right, vn, inds, vi, logps )
39+ tilde_assume(ctx, sampler, right, vn, inds, vi)
4040
4141Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42- accumulate the log probability in `logps` (separately for each thread), and return the
43- sampled value.
42+ accumulate the log probability, and return the sampled value.
4443
4544Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4645"""
47- function tilde_assume (ctx, sampler, right, vn, inds, vi, logps )
46+ function tilde_assume (ctx, sampler, right, vn, inds, vi)
4847 value, logp = tilde (ctx, sampler, right, vn, inds, vi)
49- logps[Threads . threadid ()] += logp
48+ acclogp! (vi, logp)
5049 return value
5150end
5251
7675 tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7776
7877Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
79- accumulate the log probability in `logps` (separately for each thread), and return the
80- observed value.
78+ accumulate the log probability, and return the observed value.
8179
8280Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
8381and indices; if needed, these can be accessed through this function, though.
8482"""
85- function tilde_observe (ctx, sampler, right, left, vname, vinds, vi, logps )
83+ function tilde_observe (ctx, sampler, right, left, vname, vinds, vi)
8684 logp = tilde (ctx, sampler, right, left, vi)
87- logps[Threads . threadid ()] += logp
85+ acclogp! (vi, logp)
8886 return left
8987end
9088
9189"""
92- tilde_observe(ctx, sampler, right, left, vi, logps )
90+ tilde_observe(ctx, sampler, right, left, vi)
9391
94- Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
95- (separately for each thread), and return the observed value.
92+ Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and
93+ return the observed value.
9694
9795Falls back to `tilde(ctx, sampler, right, left, vi)`.
9896"""
99- function tilde_observe (ctx, sampler, right, left, vi, logps )
97+ function tilde_observe (ctx, sampler, right, left, vi)
10098 logp = tilde (ctx, sampler, right, left, vi)
101- logps[Threads . threadid ()] += logp
99+ acclogp! (vi, logp)
102100 return left
103101end
104102
@@ -117,7 +115,7 @@ function assume(
117115 spl:: Union{SampleFromPrior,SampleFromUniform} ,
118116 dist:: Distribution ,
119117 vn:: VarName ,
120- vi:: VarInfo ,
118+ vi,
121119)
122120 if haskey (vi, vn)
123121 # Always overwrite the parameters with new ones for `SampleFromUniform`.
@@ -142,7 +140,7 @@ function observe(
142140 spl:: Union{SampleFromPrior, SampleFromUniform} ,
143141 dist:: Distribution ,
144142 value,
145- vi:: VarInfo ,
143+ vi,
146144)
147145 increment_num_produce! (vi)
148146 return Distributions. logpdf (dist, value)
@@ -201,14 +199,13 @@ end
201199 dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
202200
203201Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
204- model inputs), accumulate the log probability in `logps` (separately for each thread), and
205- return the sampled value.
202+ model inputs), accumulate the log probability, and return the sampled value.
206203
207204Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
208205"""
209- function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi, logps )
206+ function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi)
210207 value, logp = dot_tilde (ctx, sampler, right, left, vn, inds, vi)
211- logps[Threads . threadid ()] += logp
208+ acclogp! (vi, logp)
212209 return value
213210end
214211
@@ -240,7 +237,7 @@ function _dot_tilde(
240237 right:: Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}} ,
241238 left:: AbstractMatrix{>:AbstractVector} ,
242239 vn:: AbstractVector{<:VarName} ,
243- vi:: VarInfo ,
240+ vi,
244241)
245242 throw (ambiguity_error_msg ())
246243end
@@ -250,7 +247,7 @@ function dot_assume(
250247 dist:: MultivariateDistribution ,
251248 vns:: AbstractVector{<:VarName} ,
252249 var:: AbstractMatrix ,
253- vi:: VarInfo ,
250+ vi,
254251)
255252 @assert length (dist) == size (var, 1 )
256253 r = get_and_set_val! (vi, vns, dist, spl)
@@ -263,7 +260,7 @@ function dot_assume(
263260 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
264261 vns:: AbstractArray{<:VarName} ,
265262 var:: AbstractArray ,
266- vi:: VarInfo ,
263+ vi,
267264)
268265 r = get_and_set_val! (vi, vns, dists, spl)
269266 # Make sure `r` is not a matrix for multivariate distributions
@@ -276,13 +273,13 @@ function dot_assume(
276273 :: Any ,
277274 :: AbstractArray{<:VarName} ,
278275 :: Any ,
279- :: VarInfo
276+ :: Any ,
280277)
281278 error (" [DynamicPPL] $(alg_str (spl)) doesn't support vectorizing assume statement" )
282279end
283280
284281function get_and_set_val! (
285- vi:: VarInfo ,
282+ vi,
286283 vns:: AbstractVector{<:VarName} ,
287284 dist:: MultivariateDistribution ,
288285 spl:: Union{SampleFromPrior,SampleFromUniform} ,
@@ -313,7 +310,7 @@ function get_and_set_val!(
313310end
314311
315312function get_and_set_val! (
316- vi:: VarInfo ,
313+ vi,
317314 vns:: AbstractArray{<:VarName} ,
318315 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
319316 spl:: Union{SampleFromPrior,SampleFromUniform} ,
@@ -344,7 +341,7 @@ function get_and_set_val!(
344341end
345342
346343function set_val! (
347- vi:: VarInfo ,
344+ vi,
348345 vns:: AbstractVector{<:VarName} ,
349346 dist:: MultivariateDistribution ,
350347 val:: AbstractMatrix ,
@@ -356,7 +353,7 @@ function set_val!(
356353 return val
357354end
358355function set_val! (
359- vi:: VarInfo ,
356+ vi,
360357 vns:: AbstractArray{<:VarName} ,
361358 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
362359 val:: AbstractArray ,
@@ -384,36 +381,34 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
384381end
385382
386383"""
387- dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps )
384+ dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
388385
389386Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
390- accumulate the log probability in `logps` (separately for each thread), and return the
391- observed value.
387+ accumulate the log probability, and return the observed value.
392388
393389Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
394390name and indices; if needed, these can be accessed through this function, though.
395391"""
396- function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi, logps )
392+ function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi)
397393 logp = dot_tilde (ctx, sampler, right, left, vi)
398- logps[Threads . threadid ()] += logp
394+ acclogp! (vi, logp)
399395 return left
400396end
401397
402398"""
403- dot_tilde_observe(ctx, sampler, right, left, vi, logps )
399+ dot_tilde_observe(ctx, sampler, right, left, vi)
404400
405401Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
406- probability in `logps` (separately for each thread) , and return the observed value.
402+ probability, and return the observed value.
407403
408404Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
409405"""
410- function dot_tilde_observe (ctx, sampler, right, left, vi, logps )
406+ function dot_tilde_observe (ctx, sampler, right, left, vi)
411407 logp = dot_tilde (ctx, sampler, right, left, vi)
412- logps[Threads . threadid ()] += logp
408+ acclogp! (vi, logp)
413409 return left
414410end
415411
416-
417412function _dot_tilde (sampler, right, left:: AbstractArray , vi)
418413 return dot_observe (sampler, right, left, vi)
419414end
@@ -422,7 +417,7 @@ function _dot_tilde(
422417 sampler:: AbstractSampler ,
423418 right:: Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}} ,
424419 left:: AbstractMatrix{>:AbstractVector} ,
425- vi:: VarInfo ,
420+ vi,
426421)
427422 throw (ambiguity_error_msg ())
428423end
@@ -431,7 +426,7 @@ function dot_observe(
431426 spl:: Union{SampleFromPrior, SampleFromUniform} ,
432427 dist:: MultivariateDistribution ,
433428 value:: AbstractMatrix ,
434- vi:: VarInfo ,
429+ vi,
435430)
436431 increment_num_produce! (vi)
437432 DynamicPPL. DEBUG && @debug " dist = $dist "
@@ -442,7 +437,7 @@ function dot_observe(
442437 spl:: Union{SampleFromPrior, SampleFromUniform} ,
443438 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
444439 value:: AbstractArray ,
445- vi:: VarInfo ,
440+ vi,
446441)
447442 increment_num_produce! (vi)
448443 DynamicPPL. DEBUG && @debug " dists = $dists "
@@ -453,7 +448,7 @@ function dot_observe(
453448 spl:: Sampler ,
454449 :: Any ,
455450 :: Any ,
456- :: VarInfo ,
451+ :: Any ,
457452)
458453 error (" [DynamicPPL] $(alg_str (spl)) doesn't support vectorizing observe statement" )
459454end
0 commit comments