@@ -409,52 +409,74 @@ function DynamicPPL.initialstep(
409409 rng:: Random.AbstractRNG ,
410410 model:: DynamicPPL.Model ,
411411 spl:: DynamicPPL.Sampler{<:Gibbs} ,
412- vi_base :: DynamicPPL.AbstractVarInfo ;
412+ vi :: DynamicPPL.AbstractVarInfo ;
413413 initial_params= nothing ,
414414 kwargs... ,
415415)
416416 alg = spl. alg
417417 varnames = alg. varnames
418418 samplers = alg. samplers
419419
420- # Run the model once to get the varnames present + initial values to condition on.
421- vi = DynamicPPL . VarInfo ( rng, model)
422- if initial_params != = nothing
423- vi = DynamicPPL . unflatten (vi, initial_params )
424- end
420+ vi, states = gibbs_initialstep_recursive (
421+ rng, model, varnames, samplers, vi; initial_params = initial_params, kwargs ...
422+ )
423+ return Transition (model, vi), GibbsState (vi, states )
424+ end
425425
426- # Initialise each component sampler in turn, collect all their states.
427- states = []
428- for (varnames_local, sampler_local) in zip (varnames, samplers)
429- # Get the initial values for this component sampler.
430- initial_params_local = if initial_params === nothing
431- nothing
432- else
433- DynamicPPL. subset (vi, varnames_local)[:]
434- end
426+ """
427+ Take the first step of MCMC for the first component sampler, and call the same function
428+ recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
429+ and a tuple of initial states for all component samplers.
430+ """
431+ function gibbs_initialstep_recursive (
432+ rng, model, varname_vecs, samplers, vi, states= (); initial_params= nothing , kwargs...
433+ )
434+ # End recursion
435+ if isempty (varname_vecs) && isempty (samplers)
436+ return vi, states
437+ end
435438
436- # Construct the conditioned model.
437- model_local, context_local = make_conditional (model, varnames_local, vi)
439+ varnames, varname_vecs_tail ... = varname_vecs
440+ sampler, samplers_tail ... = samplers
438441
439- # Take initial step.
440- _, new_state_local = AbstractMCMC. step (
441- rng,
442- model_local,
443- sampler_local;
444- # FIXME : This will cause issues if the sampler expects initial params in unconstrained space.
445- # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
446- initial_params= initial_params_local,
447- kwargs... ,
448- )
449- new_vi_local = varinfo (new_state_local)
450- # Merge in any new variables that were introduced during the step, but that
451- # were not in the domain of the current sampler.
452- vi = merge (vi, get_global_varinfo (context_local))
453- # Merge the new values for all the variables sampled by the current sampler.
454- vi = merge (vi, new_vi_local)
455- push! (states, new_state_local)
442+ # Get the initial values for this component sampler.
443+ initial_params_local = if initial_params === nothing
444+ nothing
445+ else
446+ DynamicPPL. subset (vi, varnames)[:]
456447 end
457- return Transition (model, vi), GibbsState (vi, states)
448+
449+ # Construct the conditioned model.
450+ conditioned_model, context = make_conditional (model, varnames, vi)
451+
452+ # Take initial step with the current sampler.
453+ _, new_state = AbstractMCMC. step (
454+ rng,
455+ conditioned_model,
456+ sampler;
457+ # FIXME : This will cause issues if the sampler expects initial params in unconstrained space.
458+ # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
459+ initial_params= initial_params_local,
460+ kwargs... ,
461+ )
462+ new_vi_local = varinfo (new_state)
463+ # Merge in any new variables that were introduced during the step, but that
464+ # were not in the domain of the current sampler.
465+ vi = merge (vi, get_global_varinfo (context))
466+ # Merge the new values for all the variables sampled by the current sampler.
467+ vi = merge (vi, new_vi_local)
468+
469+ states = (states... , new_state)
470+ return gibbs_initialstep_recursive (
471+ rng,
472+ model,
473+ varname_vecs_tail,
474+ samplers_tail,
475+ vi,
476+ states;
477+ initial_params= initial_params,
478+ kwargs... ,
479+ )
458480end
459481
460482function AbstractMCMC. step (
@@ -471,17 +493,7 @@ function AbstractMCMC.step(
471493 states = state. states
472494 @assert length (samplers) == length (state. states)
473495
474- # TODO : move this into a recursive function so we can unroll when reasonable?
475- for index in 1 : length (samplers)
476- # Take the inner step.
477- sampler_local = samplers[index]
478- state_local = states[index]
479- varnames_local = varnames[index]
480- vi, new_state_local = gibbs_step_inner (
481- rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
482- )
483- states = Accessors. setindex (states, new_state_local, index)
484- end
496+ vi, states = gibbs_step_recursive (rng, model, varnames, samplers, states, vi; kwargs... )
485497 return Transition (model, vi), GibbsState (vi, states)
486498end
487499
@@ -605,19 +617,33 @@ function match_linking!!(varinfo_local, prev_state_local, model)
605617 return varinfo_local
606618end
607619
608- function gibbs_step_inner (
620+ """
621+ Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
622+ function on the tail, until there are no more samplers left.
623+ """
624+ function gibbs_step_recursive (
609625 rng:: Random.AbstractRNG ,
610626 model:: DynamicPPL.Model ,
611- varnames_local,
612- sampler_local,
613- state_local,
614- global_vi;
627+ varname_vecs,
628+ samplers,
629+ states,
630+ global_vi,
631+ new_states= ();
615632 kwargs... ,
616633)
634+ # End recursion.
635+ if isempty (varname_vecs) && isempty (samplers) && isempty (states)
636+ return global_vi, new_states
637+ end
638+
639+ varnames, varname_vecs_tail... = varname_vecs
640+ sampler, samplers_tail... = samplers
641+ state, states_tail... = states
642+
617643 # Construct the conditional model and the varinfo that this sampler should use.
618- model_local, context_local = make_conditional (model, varnames_local , global_vi)
619- varinfo_local = subset (global_vi, varnames_local )
620- varinfo_local = match_linking!! (varinfo_local, state_local , model)
644+ conditioned_model, context = make_conditional (model, varnames , global_vi)
645+ vi = subset (global_vi, varnames )
646+ vi = match_linking!! (vi, state , model)
621647
622648 # TODO (mhauru) The below may be overkill. If the varnames for this sampler are not
623649 # sampled by other samplers, we don't need to `setparams`, but could rather simply
@@ -628,18 +654,25 @@ function gibbs_step_inner(
628654 # going to be a significant expense anyway.
629655 # Set the state of the current sampler, accounting for any changes made by other
630656 # samplers.
631- state_local = setparams_varinfo!! (
632- model_local, sampler_local, state_local, varinfo_local
633- )
657+ state = setparams_varinfo!! (conditioned_model, sampler, state, vi)
634658
635659 # Take a step with the local sampler.
636- new_state_local = last (
637- AbstractMCMC. step (rng, model_local, sampler_local, state_local; kwargs... )
638- )
660+ new_state = last (AbstractMCMC. step (rng, conditioned_model, sampler, state; kwargs... ))
639661
640- new_vi_local = varinfo (new_state_local )
662+ new_vi_local = varinfo (new_state )
641663 # Merge the latest values for all the variables in the current sampler.
642- new_global_vi = merge (get_global_varinfo (context_local ), new_vi_local)
664+ new_global_vi = merge (get_global_varinfo (context ), new_vi_local)
643665 new_global_vi = setlogp!! (new_global_vi, getlogp (new_vi_local))
644- return new_global_vi, new_state_local
666+
667+ new_states = (new_states... , new_state)
668+ return gibbs_step_recursive (
669+ rng,
670+ model,
671+ varname_vecs_tail,
672+ samplers_tail,
673+ states_tail,
674+ new_global_vi,
675+ new_states;
676+ kwargs... ,
677+ )
645678end
0 commit comments