@@ -153,6 +153,25 @@ function MH(model::Model; proposal_type=AMH.StaticProposal)
153153 return AMH. MetropolisHastings (priors)
154154end
155155
156+ """
157+ MHState(varinfo::AbstractVarInfo, logjoint_internal::Real)
158+
159+ State for Metropolis-Hastings sampling.
160+
161+ `varinfo` must have the correct parameters set inside it, but its other fields
162+ (e.g. accumulators, which track logp) can in general be missing or incorrect.
163+
164+ `logjoint_internal` is the log joint probability of the model, evaluated using
165+ the parameters and linking status of `varinfo`. It should be equal to
166+ `DynamicPPL.getlogjoint_internal(varinfo)`. This information is returned by the
167+ MH sampler so we store this here to avoid re-evaluating the model
168+ unnecessarily.
169+ """
170+ struct MHState{V<: AbstractVarInfo ,L<: Real }
171+ varinfo:: V
172+ logjoint_internal:: L
173+ end
174+
156175# ####################
157176# Utility functions #
158177# ####################
@@ -297,14 +316,15 @@ end
297316
298317# Make a proposal if we don't have a covariance proposal matrix (the default).
299318function propose!! (
300- rng:: AbstractRNG , vi :: AbstractVarInfo , model:: Model , spl:: Sampler{<:MH} , proposal
319+ rng:: AbstractRNG , prev_state :: MHState , model:: Model , spl:: Sampler{<:MH} , proposal
301320)
321+ vi = prev_state. varinfo
302322 # Retrieve distribution and value NamedTuples.
303323 dt, vt = dist_val_tuple (spl, vi)
304324
305325 # Create a sampler and the previous transition.
306326 mh_sampler = AMH. MetropolisHastings (dt)
307- prev_trans = AMH. Transition (vt, DynamicPPL . getlogjoint_internal (vi) , false )
327+ prev_trans = AMH. Transition (vt, prev_state . logjoint_internal , false )
308328
309329 # Make a new transition.
310330 spl_model = DynamicPPL. contextualize (
@@ -319,24 +339,29 @@ function propose!!(
319339 trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
320340 # trans.params isa NamedTuple
321341 set_namedtuple! (vi, trans. params)
322- return vi
342+ # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
343+ # how to set this back inside vi (without re-evaluating). However, the next
344+ # MH step will require this information to calculate the acceptance
345+ # probability, so we return it together with vi.
346+ return MHState (vi, trans. lp)
323347end
324348
325349# Make a proposal if we DO have a covariance proposal matrix.
326350function propose!! (
327351 rng:: AbstractRNG ,
328- vi :: AbstractVarInfo ,
352+ prev_state :: MHState ,
329353 model:: Model ,
330354 spl:: Sampler{<:MH} ,
331355 proposal:: AdvancedMH.RandomWalkProposal ,
332356)
357+ vi = prev_state. varinfo
333358 # If this is the case, we can just draw directly from the proposal
334359 # matrix.
335360 vals = vi[:]
336361
337362 # Create a sampler and the previous transition.
338363 mh_sampler = AMH. MetropolisHastings (spl. alg. proposals)
339- prev_trans = AMH. Transition (vals, DynamicPPL . getlogjoint_internal (vi) , false )
364+ prev_trans = AMH. Transition (vals, prev_state . logjoint_internal , false )
340365
341366 # Make a new transition.
342367 spl_model = DynamicPPL. contextualize (
@@ -350,7 +375,12 @@ function propose!!(
350375 )
351376 trans, _ = AbstractMCMC. step (rng, densitymodel, mh_sampler, prev_trans)
352377 # trans.params isa AbstractVector
353- return DynamicPPL. unflatten (vi, trans. params)
378+ vi = DynamicPPL. unflatten (vi, trans. params)
379+ # Here, `trans.lp` is equal to `getlogjoint_internal(vi)`. We don't know
380+ # how to set this back inside vi (without re-evaluating). However, the next
381+ # MH step will require this information to calculate the acceptance
382+ # probability, so we return it together with vi.
383+ return MHState (vi, trans. lp)
354384end
355385
356386function DynamicPPL. initialstep (
@@ -364,18 +394,18 @@ function DynamicPPL.initialstep(
364394 # just link everything before sampling.
365395 vi = maybe_link!! (vi, spl, spl. alg. proposals, model)
366396
367- return Transition (model, vi, nothing ), vi
397+ return Transition (model, vi, nothing ), MHState (vi, DynamicPPL . getlogjoint_internal (vi))
368398end
369399
370400function AbstractMCMC. step (
371- rng:: AbstractRNG , model:: Model , spl:: Sampler{<:MH} , vi :: AbstractVarInfo ; kwargs...
401+ rng:: AbstractRNG , model:: Model , spl:: Sampler{<:MH} , state :: MHState ; kwargs...
372402)
373403 # Cases:
374404 # 1. A covariance proposal matrix
375405 # 2. A bunch of NamedTuples that specify the proposal space
376- new_vi = propose!! (rng, vi , model, spl, spl. alg. proposals)
406+ new_state = propose!! (rng, state , model, spl, spl. alg. proposals)
377407
378- return Transition (model, new_vi , nothing ), new_vi
408+ return Transition (model, new_state . varinfo , nothing ), new_state
379409end
380410
381411# ###
0 commit comments