@@ -111,56 +111,31 @@ function observe(spl::Sampler, weight)
111111 error (" DynamicPPL.observe: unmanaged inference algorithm: $(typeof (spl)) " )
112112end
113113
114- # If parameters exist, they are used and not overwritten.
115114function assume (
116- spl:: SampleFromPrior ,
115+ spl:: Union{ SampleFromPrior,SampleFromUniform} ,
117116 dist:: Distribution ,
118117 vn:: VarName ,
119118 vi:: VarInfo ,
120119)
121120 if haskey (vi, vn)
122- if is_flagged (vi, vn, " del" )
121+ # Always overwrite the parameters with new ones for `SampleFromUniform`.
122+ if spl isa SampleFromUniform || is_flagged (vi, vn, " del" )
123123 unset_flag! (vi, vn, " del" )
124- r = rand (dist)
124+ r = init (dist, spl )
125125 vi[vn] = vectorize (dist, r)
126126 settrans! (vi, false , vn)
127127 setorder! (vi, vn, get_num_produce (vi))
128128 else
129129 r = vi[vn]
130130 end
131131 else
132- r = rand (dist)
132+ r = init (dist, spl )
133133 push! (vi, vn, r, dist, spl)
134134 settrans! (vi, false , vn)
135135 end
136136 return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn))
137137end
138138
139- # Always overwrites the parameters with new ones.
140- function assume (
141- spl:: SampleFromUniform ,
142- dist:: Distribution ,
143- vn:: VarName ,
144- vi:: VarInfo ,
145- )
146- if haskey (vi, vn)
147- unset_flag! (vi, vn, " del" )
148- r = init (dist)
149- vi[vn] = vectorize (dist, r)
150- settrans! (vi, true , vn)
151- setorder! (vi, vn, get_num_produce (vi))
152- else
153- r = init (dist)
154- push! (vi, vn, r, dist, spl)
155- settrans! (vi, true , vn)
156- end
157- # NOTE: The importance weight is not correctly computed here because
158- # r is genereated from some uniform distribution which is different from the prior
159- # acclogp!(vi, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)))
160-
161- return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn))
162- end
163-
164139function observe (
165140 spl:: Union{SampleFromPrior, SampleFromUniform} ,
166141 dist:: Distribution ,
@@ -307,53 +282,60 @@ function get_and_set_val!(
307282 vi:: VarInfo ,
308283 vns:: AbstractVector{<:VarName} ,
309284 dist:: MultivariateDistribution ,
310- spl:: AbstractSampler ,
285+ spl:: Union{SampleFromPrior,SampleFromUniform} ,
311286)
312287 n = length (vns)
313288 if haskey (vi, vns[1 ])
314- if is_flagged (vi, vns[1 ], " del" )
289+ # Always overwrite the parameters with new ones for `SampleFromUniform`.
290+ if spl isa SampleFromUniform || is_flagged (vi, vns[1 ], " del" )
315291 unset_flag! (vi, vns[1 ], " del" )
316- r = spl isa SampleFromUniform ? init (dist, n) : rand (dist , n)
292+ r = init (dist, spl , n)
317293 for i in 1 : n
318294 vn = vns[i]
319295 vi[vn] = vectorize (dist, r[:, i])
296+ settrans! (vi, false , vn)
320297 setorder! (vi, vn, get_num_produce (vi))
321298 end
322299 else
323- r = vi[vns]
300+ r = vi[vns]
324301 end
325302 else
326- r = spl isa SampleFromUniform ? init (dist, n) : rand (dist , n)
303+ r = init (dist, spl , n)
327304 for i in 1 : n
328305 push! (vi, vns[i], r[:,i], dist, spl)
306+ settrans! (vi, false , vn)
329307 end
330308 end
331309 return r
332310end
311+
333312function get_and_set_val! (
334313 vi:: VarInfo ,
335314 vns:: AbstractArray{<:VarName} ,
336315 dists:: Union{Distribution, AbstractArray{<:Distribution}} ,
337- spl:: AbstractSampler ,
316+ spl:: Union{SampleFromPrior,SampleFromUniform} ,
338317)
339318 if haskey (vi, vns[1 ])
340- if is_flagged (vi, vns[1 ], " del" )
319+ # Always overwrite the parameters with new ones for `SampleFromUniform`.
320+ if spl isa SampleFromUniform || is_flagged (vi, vns[1 ], " del" )
341321 unset_flag! (vi, vns[1 ], " del" )
342- f = (vn, dist) -> spl isa SampleFromUniform ? init (dist) : rand (dist )
322+ f = (vn, dist) -> init (dist, spl )
343323 r = f .(vns, dists)
344324 for i in eachindex (vns)
345325 vn = vns[i]
346326 dist = dists isa AbstractArray ? dists[i] : dists
347327 vi[vn] = vectorize (dist, r[i])
328+ settrans! (vi, false , vn)
348329 setorder! (vi, vn, get_num_produce (vi))
349330 end
350331 else
351- r = reshape (vi[vec (vns)], size (vns))
332+ r = reshape (vi[vec (vns)], size (vns))
352333 end
353334 else
354- f = (vn, dist) -> spl isa SampleFromUniform ? init (dist) : rand (dist )
335+ f = (vn, dist) -> init (dist, spl )
355336 r = f .(vns, dists)
356337 push! .(Ref (vi), vns, r, dists, Ref (spl))
338+ settrans! .(Ref (vi), false , vns)
357339 end
358340 return r
359341end
0 commit comments