@@ -36,16 +36,17 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
3636end
3737
3838"""
39- tilde_assume(ctx, sampler, right, vn, inds, vi)
39+ tilde_assume(ctx, sampler, right, vn, inds, vi, logps )
4040
4141Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42- accumulate the log probability, and return the sampled value.
42+ accumulate the log probability in `logps` (separately for each thread), and return the
43+ sampled value.
4344
4445Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4546"""
46- function tilde_assume (ctx, sampler, right, vn, inds, vi)
47+ function tilde_assume (ctx, sampler, right, vn, inds, vi, logps )
4748 value, logp = tilde (ctx, sampler, right, vn, inds, vi)
48- acclogp! (vi, logp)
49+ logps[Threads . threadid ()] += logp
4950 return value
5051end
5152
7576 tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7677
7778Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
78- accumulate the log probability, and return the observed value.
79+ accumulate the log probability in `logps` (separately for each thread), and return the
80+ observed value.
7981
8082Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
8183and indices; if needed, these can be accessed through this function, though.
8284"""
83- function tilde_observe (ctx, sampler, right, left, vname, vinds, vi)
85+ function tilde_observe (ctx, sampler, right, left, vname, vinds, vi, logps )
8486 logp = tilde (ctx, sampler, right, left, vi)
85- acclogp! (vi, logp)
87+ logps[Threads . threadid ()] += logp
8688 return left
8789end
8890
8991"""
90- tilde_observe(ctx, sampler, right, left, vi)
92+ tilde_observe(ctx, sampler, right, left, vi, logps )
9193
92- Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
93- observed value.
94+ Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
95+ (separately for each thread), and return the observed value.
9496
9597Falls back to `tilde(ctx, sampler, right, left, vi)`.
9698"""
97- function tilde_observe (ctx, sampler, right, left, vi)
99+ function tilde_observe (ctx, sampler, right, left, vi, logps )
98100 logp = tilde (ctx, sampler, right, left, vi)
99- acclogp! (vi, logp)
101+ logps[Threads . threadid ()] += logp
100102 return left
101103end
102104
@@ -199,13 +201,14 @@ end
199201 dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200202
201203Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
202- model inputs), accumulate the log probability, and return the sampled value.
204+ model inputs), accumulate the log probability in `logps` (separately for each thread), and
205+ return the sampled value.
203206
204207Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
205208"""
206- function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi)
209+ function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi, logps )
207210 value, logp = dot_tilde (ctx, sampler, right, left, vn, inds, vi)
208- acclogp! (vi, logp)
211+ logps[Threads . threadid ()] += logp
209212 return value
210213end
211214
@@ -381,31 +384,32 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
381384end
382385
383386"""
384- dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
387+ dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps )
385388
386389Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
387- accumulate the log probability, and return the observed value.
390+ accumulate the log probability in `logps` (separately for each thread), and return the
391+ observed value.
388392
389393Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
390394name and indices; if needed, these can be accessed through this function, though.
391395"""
392- function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi)
396+ function dot_tilde_observe (ctx, sampler, right, left, vn, inds, vi, logps )
393397 logp = dot_tilde (ctx, sampler, right, left, vi)
394- acclogp! (vi, logp)
398+ logps[Threads . threadid ()] += logp
395399 return left
396400end
397401
398402"""
399- dot_tilde_observe(ctx, sampler, right, left, vi)
403+ dot_tilde_observe(ctx, sampler, right, left, vi, logps )
400404
401405Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
402- probability, and return the observed value.
406+ probability in `logps` (separately for each thread) , and return the observed value.
403407
404408Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
405409"""
406- function dot_tilde_observe (ctx, sampler, right, left, vi)
410+ function dot_tilde_observe (ctx, sampler, right, left, vi, logps )
407411 logp = dot_tilde (ctx, sampler, right, left, vi)
408- acclogp! (vi, logp)
412+ logps[Threads . threadid ()] += logp
409413 return left
410414end
411415
0 commit comments