@@ -19,47 +19,47 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
1919_getindex (x, inds:: Tuple{} ) = x
2020
2121# assume
22- function tilde (ctx:: DefaultContext , sampler, right, vn:: VarName , _, vi)
23- return _tilde (sampler, right, vn, vi)
22+ function tilde (rng, ctx:: DefaultContext , sampler, right, vn:: VarName , _, vi)
23+ return _tilde (rng, sampler, right, vn, vi)
2424end
25- function tilde (ctx:: PriorContext , sampler, right, vn:: VarName , inds, vi)
25+ function tilde (rng, ctx:: PriorContext , sampler, right, vn:: VarName , inds, vi)
2626 if ctx. vars != = nothing
2727 vi[vn] = vectorize (right, _getindex (getfield (ctx. vars, getsym (vn)), inds))
2828 settrans! (vi, false , vn)
2929 end
30- return _tilde (sampler, right, vn, vi)
30+ return _tilde (rng, sampler, right, vn, vi)
3131end
32- function tilde (ctx:: LikelihoodContext , sampler, right, vn:: VarName , inds, vi)
32+ function tilde (rng, ctx:: LikelihoodContext , sampler, right, vn:: VarName , inds, vi)
3333 if ctx. vars != = nothing
3434 vi[vn] = vectorize (right, _getindex (getfield (ctx. vars, getsym (vn)), inds))
3535 settrans! (vi, false , vn)
3636 end
37- return _tilde (sampler, NoDist (right), vn, vi)
37+ return _tilde (rng, sampler, NoDist (right), vn, vi)
3838end
39- function tilde (ctx:: MiniBatchContext , sampler, right, left:: VarName , inds, vi)
40- return tilde (ctx. ctx, sampler, right, left, inds, vi)
39+ function tilde (rng, ctx:: MiniBatchContext , sampler, right, left:: VarName , inds, vi)
40+ return tilde (rng, ctx. ctx, sampler, right, left, inds, vi)
4141end
4242
4343"""
44- tilde_assume(ctx, sampler, right, vn, inds, vi)
44+ tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
4545
4646Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
4747accumulate the log probability, and return the sampled value.
4848
49- Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
49+ Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`.
5050"""
51- function tilde_assume (ctx, sampler, right, vn, inds, vi)
52- value, logp = tilde (ctx, sampler, right, vn, inds, vi)
51+ function tilde_assume (rng, ctx, sampler, right, vn, inds, vi)
52+ value, logp = tilde (rng, ctx, sampler, right, vn, inds, vi)
5353 acclogp! (vi, logp)
5454 return value
5555end
5656
5757
58- function _tilde (sampler, right, vn:: VarName , vi)
59- return assume (sampler, right, vn, vi)
58+ function _tilde (rng, sampler, right, vn:: VarName , vi)
59+ return assume (rng, sampler, right, vn, vi)
6060end
61- function _tilde (sampler, right:: NamedDist , vn:: VarName , vi)
62- return _tilde (sampler, right. dist, right. name, vi)
61+ function _tilde (rng, sampler, right:: NamedDist , vn:: VarName , vi)
62+ return _tilde (rng, sampler, right. dist, right. name, vi)
6363end
6464
6565# observe
108108
109109_tilde (sampler, right, left, vi) = observe (sampler, right, left, vi)
110110
111- function assume (spl:: Sampler , dist)
111+ function assume (rng, spl:: Sampler , dist)
112112 error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
113113end
114114
@@ -117,6 +117,7 @@ function observe(spl::Sampler, weight)
117117end
118118
119119function assume (
120+ rng,
120121 spl:: Union{SampleFromPrior,SampleFromUniform} ,
121122 dist:: Distribution ,
122123 vn:: VarName ,
@@ -126,15 +127,15 @@ function assume(
126127 # Always overwrite the parameters with new ones for `SampleFromUniform`.
127128 if spl isa SampleFromUniform || is_flagged (vi, vn, " del" )
128129 unset_flag! (vi, vn, " del" )
129- r = init (dist, spl)
130+ r = init (rng, dist, spl)
130131 vi[vn] = vectorize (dist, r)
131132 settrans! (vi, false , vn)
132133 setorder! (vi, vn, get_num_produce (vi))
133134 else
134135 r = vi[vn]
135136 end
136137 else
137- r = init (dist, spl)
138+ r = init (rng, dist, spl)
138139 push! (vi, vn, r, dist, spl)
139140 settrans! (vi, false , vn)
140141 end
@@ -154,11 +155,12 @@ end
154155# .~ functions
155156
156157# assume
157- function dot_tilde (ctx:: DefaultContext , sampler, right, left, vn:: VarName , _, vi)
158+ function dot_tilde (rng, ctx:: DefaultContext , sampler, right, left, vn:: VarName , _, vi)
158159 vns, dist = get_vns_and_dist (right, left, vn)
159- return _dot_tilde (sampler, dist, left, vns, vi)
160+ return _dot_tilde (rng, sampler, dist, left, vns, vi)
160161end
161162function dot_tilde (
163+ rng,
162164 ctx:: LikelihoodContext ,
163165 sampler,
164166 right,
@@ -175,12 +177,13 @@ function dot_tilde(
175177 else
176178 vns, dist = get_vns_and_dist (right, left, vn)
177179 end
178- return _dot_tilde (sampler, NoDist (dist), left, vns, vi)
180+ return _dot_tilde (rng, sampler, NoDist (dist), left, vns, vi)
179181end
180- function dot_tilde (ctx:: MiniBatchContext , sampler, right, left, vn:: VarName , inds, vi)
181- return dot_tilde (ctx. ctx, sampler, right, left, vn, inds, vi)
182+ function dot_tilde (rng, ctx:: MiniBatchContext , sampler, right, left, vn:: VarName , inds, vi)
183+ return dot_tilde (rng, ctx. ctx, sampler, right, left, vn, inds, vi)
182184end
183185function dot_tilde (
186+ rng,
184187 ctx:: PriorContext ,
185188 sampler,
186189 right,
@@ -197,19 +200,19 @@ function dot_tilde(
197200 else
198201 vns, dist = get_vns_and_dist (right, left, vn)
199202 end
200- return _dot_tilde (sampler, dist, left, vns, vi)
203+ return _dot_tilde (rng, sampler, dist, left, vns, vi)
201204end
202205
203206"""
204- dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
207+ dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)
205208
206209Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
207210model inputs), accumulate the log probability, and return the sampled value.
208211
209- Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
212+ Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`.
210213"""
211- function dot_tilde_assume (ctx, sampler, right, left, vn, inds, vi)
212- value, logp = dot_tilde (ctx, sampler, right, left, vn, inds, vi)
214+ function dot_tilde_assume (rng, ctx, sampler, right, left, vn, inds, vi)
215+ value, logp = dot_tilde (rng, ctx, sampler, right, left, vn, inds, vi)
213216 acclogp! (vi, logp)
214217 return value
215218end
@@ -232,12 +235,13 @@ function get_vns_and_dist(
232235 return getvn .(CartesianIndices (var)), dist
233236end
234237
235- function _dot_tilde (sampler, right, left, vns:: AbstractArray{<:VarName} , vi)
236- return dot_assume (sampler, right, vns, left, vi)
238+ function _dot_tilde (rng, sampler, right, left, vns:: AbstractArray{<:VarName} , vi)
239+ return dot_assume (rng, sampler, right, vns, left, vi)
237240end
238241
239242# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
240243function _dot_tilde (
244+ rng,
241245 sampler:: AbstractSampler ,
242246 right:: Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}} ,
243247 left:: AbstractMatrix{>:AbstractVector} ,
@@ -248,32 +252,35 @@ function _dot_tilde(
248252end
249253
250254function dot_assume (
255+ rng,
251256 spl:: Union{SampleFromPrior, SampleFromUniform} ,
252257 dist:: MultivariateDistribution ,
253258 vns:: AbstractVector{<:VarName} ,
254259 var:: AbstractMatrix ,
255260 vi,
256261)
257262 @assert length (dist) == size (var, 1 )
258- r = get_and_set_val! (vi, vns, dist, spl)
263+ r = get_and_set_val! (rng, vi, vns, dist, spl)
259264 lp = sum (Bijectors. logpdf_with_trans (dist, r, istrans (vi, vns[1 ])))
260265 var .= r
261266 return var, lp
262267end
263268function dot_assume (
269+ rng,
264270 spl:: Union{SampleFromPrior, SampleFromUniform} ,
265271 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
266272 vns:: AbstractArray{<:VarName} ,
267273 var:: AbstractArray ,
268274 vi,
269275)
270- r = get_and_set_val! (vi, vns, dists, spl)
276+ r = get_and_set_val! (rng, vi, vns, dists, spl)
271277 # Make sure `r` is not a matrix for multivariate distributions
272278 lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans (vi, vns[1 ])))
273279 var .= r
274280 return var, lp
275281end
276282function dot_assume (
283+ rng,
277284 spl:: Sampler ,
278285 :: Any ,
279286 :: AbstractArray{<:VarName} ,
@@ -284,6 +291,7 @@ function dot_assume(
284291end
285292
286293function get_and_set_val! (
294+ rng,
287295 vi,
288296 vns:: AbstractVector{<:VarName} ,
289297 dist:: MultivariateDistribution ,
@@ -294,7 +302,7 @@ function get_and_set_val!(
294302 # Always overwrite the parameters with new ones for `SampleFromUniform`.
295303 if spl isa SampleFromUniform || is_flagged (vi, vns[1 ], " del" )
296304 unset_flag! (vi, vns[1 ], " del" )
297- r = init (dist, spl, n)
305+ r = init (rng, dist, spl, n)
298306 for i in 1 : n
299307 vn = vns[i]
300308 vi[vn] = vectorize (dist, r[:, i])
@@ -305,7 +313,7 @@ function get_and_set_val!(
305313 r = vi[vns]
306314 end
307315 else
308- r = init (dist, spl, n)
316+ r = init (rng, dist, spl, n)
309317 for i in 1 : n
310318 vn = vns[i]
311319 push! (vi, vn, r[:,i], dist, spl)
@@ -316,6 +324,7 @@ function get_and_set_val!(
316324end
317325
318326function get_and_set_val! (
327+ rng,
319328 vi,
320329 vns:: AbstractArray{<:VarName} ,
321330 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
@@ -325,7 +334,7 @@ function get_and_set_val!(
325334 # Always overwrite the parameters with new ones for `SampleFromUniform`.
326335 if spl isa SampleFromUniform || is_flagged (vi, vns[1 ], " del" )
327336 unset_flag! (vi, vns[1 ], " del" )
328- f = (vn, dist) -> init (dist, spl)
337+ f = (vn, dist) -> init (rng, dist, spl)
329338 r = f .(vns, dists)
330339 for i in eachindex (vns)
331340 vn = vns[i]
@@ -338,7 +347,7 @@ function get_and_set_val!(
338347 r = reshape (vi[vec (vns)], size (vns))
339348 end
340349 else
341- f = (vn, dist) -> init (dist, spl)
350+ f = (vn, dist) -> init (rng, dist, spl)
342351 r = f .(vns, dists)
343352 push! .(Ref (vi), vns, r, dists, Ref (spl))
344353 settrans! .(Ref (vi), false , vns)
0 commit comments