@@ -103,24 +103,48 @@ function observe(spl::Sampler, weight)
103103 error (" DynamicPPL.observe: unmanaged inference algorithm: $(typeof (spl)) " )
104104end
105105
106+ # If parameters exist, they are used and not overwritten.
106107function assume (
107- spl:: Union{ SampleFromPrior, SampleFromUniform} ,
108+ spl:: SampleFromPrior ,
108109 dist:: Distribution ,
109110 vn:: VarName ,
110111 vi:: VarInfo ,
111112)
112113 if haskey (vi, vn)
113114 if is_flagged (vi, vn, " del" )
114115 unset_flag! (vi, vn, " del" )
115- r = spl isa SampleFromUniform ? init (dist) : rand (dist)
116+ r = rand (dist)
116117 vi[vn] = vectorize (dist, r)
118+ settrans! (vi, false , vn)
117119 setorder! (vi, vn, get_num_produce (vi))
118120 else
119- r = vi[vn]
121+ r = vi[vn]
120122 end
121123 else
122- r = isa (spl, SampleFromUniform) ? init (dist) : rand (dist)
124+ r = rand (dist)
125+ push! (vi, vn, r, dist, spl)
126+ settrans! (vi, false , vn)
127+ end
128+ return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn))
129+ end
130+
131+ # Always overwrites the parameters with new ones.
132+ function assume (
133+ spl:: SampleFromUniform ,
134+ dist:: Distribution ,
135+ vn:: VarName ,
136+ vi:: VarInfo ,
137+ )
138+ if haskey (vi, vn)
139+ unset_flag! (vi, vn, " del" )
140+ r = init (dist)
141+ vi[vn] = vectorize (dist, r)
142+ settrans! (vi, true , vn)
143+ setorder! (vi, vn, get_num_produce (vi))
144+ else
145+ r = init (dist)
123146 push! (vi, vn, r, dist, spl)
147+ settrans! (vi, true , vn)
124148 end
125149 # NOTE: The importance weight is not correctly computed here because
126150 # r is genereated from some uniform distribution which is different from the prior
0 commit comments