|
38 | 38 | """ |
39 | 39 | tilde_assume(ctx, sampler, right, vn, inds, vi) |
40 | 40 |
|
41 | | -This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where |
42 | | -`x` does not occur in the model inputs. |
| 41 | +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), |
| 42 | +accumulate the log probability, and return the sampled value. |
43 | 43 |
|
44 | 44 | Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`. |
45 | 45 | """ |
46 | 46 | function tilde_assume(ctx, sampler, right, vn, inds, vi) |
47 | | - return tilde(ctx, sampler, right, vn, inds, vi) |
| 47 | + value, logp = tilde(ctx, sampler, right, vn, inds, vi) |
| 48 | + acclogp!(vi, logp) |
| 49 | + return value |
48 | 50 | end |
49 | 51 |
|
50 | 52 |
|
|
72 | 74 | """ |
73 | 75 | tilde_observe(ctx, sampler, right, left, vname, vinds, vi) |
74 | 76 |
|
75 | | -This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where |
76 | | -`x` does occur in the model inputs. |
| 77 | +Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), |
| 78 | +accumulate the log probability, and return the observed value. |
77 | 79 |
|
78 | | -Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable |
79 | | -name and indices; if needed, these can be accessed through this function, though. |
| 80 | +Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name |
| 81 | +and indices; if needed, these can be accessed through this function, though. |
80 | 82 | """ |
81 | 83 | function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) |
82 | | - return tilde(ctx, sampler, right, left, vi) |
| 84 | + logp = tilde(ctx, sampler, right, left, vi) |
| 85 | + acclogp!(vi, logp) |
| 86 | + return left |
83 | 87 | end |
84 | 88 |
|
85 | 89 | """ |
86 | 90 | tilde_observe(ctx, sampler, right, left, vi) |
87 | 91 |
|
88 | | -This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`. |
| 92 | +Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the |
| 93 | +observed value. |
| 94 | +
|
89 | 95 | Falls back to `tilde(ctx, sampler, right, left, vi)`. |
90 | 96 | """ |
91 | 97 | function tilde_observe(ctx, sampler, right, left, vi) |
92 | | - return tilde(ctx, sampler, right, left, vi) |
| 98 | + logp = tilde(ctx, sampler, right, left, vi) |
| 99 | + acclogp!(vi, logp) |
| 100 | + return left |
93 | 101 | end |
94 | 102 |
|
95 | 103 |
|
@@ -191,13 +199,15 @@ end |
191 | 199 | """ |
192 | 200 | dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) |
193 | 201 |
|
194 | | -This method is applied in the generated code for assumed vectorized variables, e.g., `x .~ |
195 | | -MvNormal()` where `x` does not occur in the model inputs. |
| 202 | +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the |
| 203 | +model inputs), accumulate the log probability, and return the sampled value. |
196 | 204 |
|
197 | 205 | Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`. |
198 | 206 | """ |
199 | 207 | function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) |
200 | | - return dot_tilde(ctx, sampler, right, left, vn, inds, vi) |
| 208 | + value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi) |
| 209 | + acclogp!(vi, logp) |
| 210 | + return value |
201 | 211 | end |
202 | 212 |
|
203 | 213 |
|
@@ -367,24 +377,30 @@ end |
367 | 377 | """ |
368 | 378 | dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) |
369 | 379 |
|
370 | | -This method is applied in the generated code for vectorized observed variables, e.g., `x .~ |
371 | | -MvNormal()` where `x` does occur the model inputs. |
| 380 | +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), |
| 381 | +accumulate the log probability, and return the observed value. |
372 | 382 |
|
373 | 383 | Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable |
374 | 384 | name and indices; if needed, these can be accessed through this function, though. |
375 | 385 | """ |
376 | 386 | function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi) |
377 | | - return dot_tilde(ctx, sampler, right, left, vi) |
| 387 | + logp = dot_tilde(ctx, sampler, right, left, vi) |
| 388 | + acclogp!(vi, logp) |
| 389 | + return left |
378 | 390 | end |
379 | 391 |
|
380 | 392 | """ |
381 | 393 | dot_tilde_observe(ctx, sampler, right, left, vi) |
382 | 394 |
|
383 | | -This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~ |
384 | | -MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. |
| 395 | +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log |
| 396 | +probability, and return the observed value. |
| 397 | +
|
| 398 | +Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. |
385 | 399 | """ |
386 | 400 | function dot_tilde_observe(ctx, sampler, right, left, vi) |
387 | | - return dot_tilde(ctx, sampler, right, left, vi) |
| 401 | + logp = dot_tilde(ctx, sampler, right, left, vi) |
| 402 | + acclogp!(vi, logp) |
| 403 | + return left |
388 | 404 | end |
389 | 405 |
|
390 | 406 |
|
|
0 commit comments