@@ -205,17 +205,23 @@ function AbstractMCMC.step(
205205 return Transition (t. z, tstat), newstate
206206end
207207
208- struct SGHMCState{T<: AbstractVector{<:Real} }
208+ struct SGHMCState{
209+ TTrans<: Transition ,
210+ TMetric<: AbstractMetric ,
211+ TKernel<: AbstractMCMCKernel ,
212+ TAdapt<: Adaptation.AbstractAdaptor ,
213+ T<: AbstractVector{<:Real} ,
214+ }
209215 " Index of current iteration."
210- i
216+ i:: Int
211217 " Current [`Transition`](@ref)."
212- transition
218+ transition:: TTrans
213219 " Current [`AbstractMetric`](@ref), possibly adapted."
214- metric
220+ metric:: TMetric
215221 " Current [`AbstractMCMCKernel`](@ref)."
216- κ
222+ κ:: TKernel
217223 " Current [`AbstractAdaptor`](@ref)."
218- adaptor
224+ adaptor:: TAdapt
219225 velocity:: T
220226end
221227getadaptor (state:: SGHMCState ) = state. adaptor
@@ -252,7 +258,7 @@ function AbstractMCMC.step(
252258 # Get an initial sample.
253259 h, t = AdvancedHMC. sample_init (rng, hamiltonian, initial_params)
254260
255- state = SGHMCState (0 , t, metric, κ, adaptor, initial_params, zero (initial_params) )
261+ state = SGHMCState (0 , t, metric, κ, adaptor, initial_params)
256262
257263 return AbstractMCMC. step (rng, model, spl, state; kwargs... )
258264end
@@ -265,6 +271,14 @@ function AbstractMCMC.step(
265271 n_adapts:: Int = 0 ,
266272 kwargs... ,
267273)
274+ if haskey (kwargs, :nadapts )
275+ throw (
276+ ArgumentError (
277+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
278+ ),
279+ )
280+ end
281+
268282 i = state. i + 1
269283 t_old = state. transition
270284 adaptor = state. adaptor
@@ -289,14 +303,14 @@ function AbstractMCMC.step(
289303 α = spl. momentum_decay
290304 newv = (1 - α) .* v .+ η .* grad .+ sqrt (2 * η * α) .* randn (rng, eltype (v), length (v))
291305
306+ # Make new transition.
307+ t = transition (rng, h, κ, t_old. z)
308+
292309 # Adapt h and spl.
293310 tstat = stat (t)
294311 h, κ, isadapted = adapt! (h, κ, adaptor, i, n_adapts, θ, tstat. acceptance_rate)
295312 tstat = merge (tstat, (is_adapt= isadapted,))
296313
297- # Make new transition.
298- t = transition (rng, h, κ, t_old. z)
299-
300314 # Compute next sample and state.
301315 sample = Transition (t. z, tstat)
302316 newstate = SGHMCState (i, t, h. metric, κ, adaptor, newv)
0 commit comments