@@ -11,6 +11,12 @@ Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
1111resampling.
1212"""
1313function set_all_del! (vi:: AbstractVarInfo )
14+ # TODO (penelopeysm): Instead of being a 'del' flag on the VarInfo, we
15+ # could either:
16+ # - keep a boolean 'resample' flag on the trace, or
17+ # - modify the model context appropriately.
18+ # However, this refactoring will have to wait until InitContext is
19+ # merged into DPPL.
1420 for vn in keys (vi)
1521 DynamicPPL. set_flag! (vi, vn, " del" )
1622 end
5965function AdvancedPS. advance! (
6066 trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} , isref:: Bool = false
6167)
62- # We want to increment num produce for the VarInfo stored in the trace. The trace is
63- # mutable, so we create a new model with the incremented VarInfo and set it in the trace
64- model = trace. model
65- model = Accessors. @set model. f. varinfo = DynamicPPL. increment_num_produce!! (
66- model. f. varinfo
67- )
68- trace. model = model
6968 # Make sure we load/reset the rng in the new replaying mechanism
7069 isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
7170 score = consume (trace. model. ctask)
7271 return score
7372end
7473
7574function AdvancedPS. delete_retained! (trace:: TracedModel )
76- # TODO (DPPL0.37/penelopeysm): Explain this a bit better.
77- #
7875 # This method is called if, during a CSMC update, we perform a resampling
7976 # and choose the reference particle as the trajectory to carry on from.
8077 # In such a case, we need to ensure that when we continue sampling (i.e.
8178 # the next time we hit tilde_assume), we don't use the values in the
8279 # reference particle but rather sample new values.
83- # In this implementation, we indiscriminately set the 'del' flag for all
84- # variables in the VarInfo. This is slightly overkill: it is not necessary
85- # to set the 'del' flag for variables that were already sampled. However,
86- # it allows us to avoid using DynamicPPL.set_retained_vns_del!.
80+ #
81+ # Here, we indiscriminately set the 'del' flag for all variables in the
82+ # VarInfo. This is slightly overkill: it is not necessary to set the 'del'
83+ # flag for variables that were already sampled. However, it allows us to
84+ # avoid keeping track of which variables were sampled, which leads to many
85+ # simplifications in the VarInfo data structure.
8786 set_all_del! (trace. varinfo)
8887 return trace
8988end
9089
9190function AdvancedPS. reset_model (trace:: TracedModel )
92- return Accessors . @set trace. varinfo = DynamicPPL . reset_num_produce!! (trace . varinfo)
91+ return trace
9392end
9493
9594function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
@@ -213,7 +212,6 @@ function DynamicPPL.initialstep(
213212)
214213 # Reset the VarInfo.
215214 vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
216- vi = DynamicPPL. reset_num_produce!! (vi)
217215 set_all_del! (vi)
218216 vi = DynamicPPL. resetlogp!! (vi)
219217 vi = DynamicPPL. empty!! (vi)
@@ -344,7 +342,6 @@ function DynamicPPL.initialstep(
344342)
345343 vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
346344 # Reset the VarInfo before new sweep
347- vi = DynamicPPL. reset_num_produce!! (vi)
348345 set_all_del! (vi)
349346 vi = DynamicPPL. resetlogp!! (vi)
350347
@@ -366,11 +363,6 @@ function DynamicPPL.initialstep(
366363
367364 # Compute the first transition.
368365 _vi = reference. model. f. varinfo
369- # Unset any 'del' flags before we actually construct the transition.
370- # This is necessary because the model will be re-evaluated and we
371- # want to make sure we do use the values in the reference particle
372- # instead of resampling them.
373- unset_all_del! (_vi)
374366 transition = PGTransition (model, _vi, logevidence)
375367
376368 return transition, PGState (_vi, reference. rng)
@@ -382,10 +374,10 @@ function AbstractMCMC.step(
382374 # Reset the VarInfo before new sweep.
383375 vi = state. vi
384376 vi = DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
385- vi = DynamicPPL. reset_num_produce!! (vi)
386377 vi = DynamicPPL. resetlogp!! (vi)
387378
388379 # Create reference particle for which the samples will be retained.
380+ unset_all_del! (vi)
389381 reference = AdvancedPS. forkr (AdvancedPS. Trace (model, spl, vi, state. rng))
390382
391383 # For all other particles, do not retain the variables but resample them.
@@ -412,11 +404,6 @@ function AbstractMCMC.step(
412404
413405 # Compute the transition.
414406 _vi = newreference. model. f. varinfo
415- # Unset any 'del' flags before we actually construct the transition.
416- # This is necessary because the model will be re-evaluated and we
417- # want to make sure we do use the values in the reference particle
418- # instead of resampling them.
419- unset_all_del! (_vi)
420407 transition = PGTransition (model, _vi, logevidence)
421408
422409 return transition, PGState (_vi, newreference. rng)
@@ -499,12 +486,11 @@ function DynamicPPL.assume(
499486 vi = push!! (vi, vn, r, dist)
500487 elseif DynamicPPL. is_flagged (vi, vn, " del" )
501488 DynamicPPL. unset_flag! (vi, vn, " del" ) # Reference particle parent
502- r = rand (trng, dist)
503- vi[vn] = DynamicPPL. tovec (r)
504489 # TODO (mhauru):
505490 # The below is the only line that differs from assume called on SampleFromPrior.
506- # Could we just call assume on SampleFromPrior and then `setorder!!` after that?
507- vi = DynamicPPL. setorder!! (vi, vn, DynamicPPL. get_num_produce (vi))
491+ # Could we just call assume on SampleFromPrior with a specific rng?
492+ r = rand (trng, dist)
493+ vi[vn] = DynamicPPL. tovec (r)
508494 else
509495 r = vi[vn]
510496 end
@@ -546,8 +532,6 @@ function AdvancedPS.Trace(
546532 rng:: AdvancedPS.TracedRNG ,
547533)
548534 newvarinfo = deepcopy (varinfo)
549- newvarinfo = DynamicPPL. reset_num_produce!! (newvarinfo)
550-
551535 tmodel = TracedModel (model, sampler, newvarinfo, rng)
552536 newtrace = AdvancedPS. Trace (tmodel, rng)
553537 return newtrace
0 commit comments