@@ -83,14 +83,20 @@ function externalsampler(
8383 return ExternalSampler (sampler, adtype, Val (unconstrained))
8484end
8585
86- struct TuringState{S,M,V}
86+ # TODO (penelopeysm): Can't we clean this up somehow?
87+ struct TuringState{S,V1,M,V}
8788 state:: S
89+ # Note that this varinfo must have the correct parameters set; but logp
90+ # does not matter as it will be re-evaluated
91+ varinfo:: V1
8892 # Note that in general the VarInfo inside this LogDensityFunction will have
8993 # junk parameters and logp. It only exists to provide structure
9094 ldf:: DynamicPPL.LogDensityFunction{M,V}
9195end
9296
93- get_varinfo (state:: TuringState ) = state. ldf. varinfo
97+ # get_varinfo should return something from which the correct parameters can be
98+ # obtained, hence we use state.varinfo rather than state.ldf.varinfo
99+ get_varinfo (state:: TuringState ) = state. varinfo
94100get_varinfo (state:: AbstractVarInfo ) = state
95101
96102getparams (:: DynamicPPL.Model , transition:: AdvancedHMC.Transition ) = transition. z. θ
@@ -148,8 +154,10 @@ function AbstractMCMC.step(
148154 end
149155
150156 new_parameters = getparams (f. model, state_inner)
151- vi = DynamicPPL. unflatten (f. varinfo, new_parameters)
152- return (Transition (f. model, vi, transition_inner), TuringState (state_inner, f))
157+ new_vi = DynamicPPL. unflatten (f. varinfo, new_parameters)
158+ return (
159+ Transition (f. model, new_vi, transition_inner), TuringState (state_inner, new_vi, f)
160+ )
153161end
154162
155163function AbstractMCMC. step (
@@ -168,6 +176,8 @@ function AbstractMCMC.step(
168176 )
169177
170178 new_parameters = getparams (f. model, state_inner)
171- vi = DynamicPPL. unflatten (f. varinfo, new_parameters)
172- return (Transition (f. model, vi, transition_inner), TuringState (state_inner, f))
179+ new_vi = DynamicPPL. unflatten (f. varinfo, new_parameters)
180+ return (
181+ Transition (f. model, new_vi, transition_inner), TuringState (state_inner, new_vi, f)
182+ )
173183end
0 commit comments