@@ -205,6 +205,105 @@ function AbstractMCMC.step(
205205 return Transition (t. z, tstat), newstate
206206end
207207
208+ struct SGHMCState{T<: AbstractVector{<:Real} }
209+ " Index of current iteration."
210+ i
211+ " Current [`Transition`](@ref)."
212+ transition
213+ " Current [`AbstractMetric`](@ref), possibly adapted."
214+ metric
215+ " Current [`AbstractMCMCKernel`](@ref)."
216+ κ
217+ " Current [`AbstractAdaptor`](@ref)."
218+ adaptor
219+ velocity:: T
220+ end
221+ getadaptor (state:: SGHMCState ) = state. adaptor
222+ getmetric (state:: SGHMCState ) = state. metric
223+ getintegrator (state:: SGHMCState ) = state. κ. τ. integrator
224+
225+ function AbstractMCMC. step (
226+ rng:: Random.AbstractRNG ,
227+ model:: AbstractMCMC.LogDensityModel ,
228+ spl:: SGHMC ;
229+ initial_params= nothing ,
230+ kwargs... ,
231+ )
232+ # Unpack model
233+ logdensity = model. logdensity
234+
235+ # Define metric
236+ metric = make_metric (spl, logdensity)
237+
238+ # Construct the hamiltonian using the initial metric
239+ hamiltonian = Hamiltonian (metric, model)
240+
241+ # Compute initial sample and state.
242+ initial_params = make_initial_params (rng, spl, logdensity, initial_params)
243+ ϵ = make_step_size (rng, spl, hamiltonian, initial_params)
244+ integrator = make_integrator (spl, ϵ)
245+
246+ # Make kernel
247+ κ = make_kernel (spl, integrator)
248+
249+ # Make adaptor
250+ adaptor = make_adaptor (spl, metric, integrator)
251+
252+ # Get an initial sample.
253+ h, t = AdvancedHMC. sample_init (rng, hamiltonian, initial_params)
254+
255+ state = SGHMCState (0 , t, metric, κ, adaptor, initial_params, zero (initial_params))
256+
257+ return AbstractMCMC. step (rng, model, spl, state; kwargs... )
258+ end
259+
260+ function AbstractMCMC. step (
261+ rng:: AbstractRNG ,
262+ model:: AbstractMCMC.LogDensityModel ,
263+ spl:: SGHMC ,
264+ state:: SGHMCState ;
265+ n_adapts:: Int = 0 ,
266+ kwargs... ,
267+ )
268+ i = state. i + 1
269+ t_old = state. transition
270+ adaptor = state. adaptor
271+ κ = state. κ
272+ metric = state. metric
273+
274+ # Reconstruct hamiltonian.
275+ h = Hamiltonian (metric, model)
276+
277+ # Compute gradient of log density.
278+ logdensity_and_gradient = Base. Fix1 (
279+ LogDensityProblems. logdensity_and_gradient, model. logdensity
280+ )
281+ θ = t_old. z. θ
282+ grad = last (logdensity_and_gradient (θ))
283+
284+ # Update latent variables and velocity according to
285+ # equation (15) of Chen et al. (2014)
286+ v = state. velocity
287+ θ .+ = v
288+ η = spl. learning_rate
289+ α = spl. momentum_decay
290+ newv = (1 - α) .* v .+ η .* grad .+ sqrt (2 * η * α) .* randn (rng, eltype (v), length (v))
291+
292+ # Adapt h and spl.
293+ tstat = stat (t)
294+ h, κ, isadapted = adapt! (h, κ, adaptor, i, n_adapts, θ, tstat. acceptance_rate)
295+ tstat = merge (tstat, (is_adapt= isadapted,))
296+
297+ # Make new transition.
298+ t = transition (rng, h, κ, t_old. z)
299+
300+ # Compute next sample and state.
301+ sample = Transition (t. z, tstat)
302+ newstate = SGHMCState (i, t, h. metric, κ, adaptor, newv)
303+
304+ return sample, newstate
305+ end
306+
208307# ###############
209308# ## Callback ###
210309# ###############
@@ -392,6 +491,10 @@ function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractInte
392491 return NoAdaptation ()
393492end
394493
494+ function make_adaptor (spl:: SGHMC , metric:: AbstractMetric , integrator:: AbstractIntegrator )
495+ return NoAdaptation ()
496+ end
497+
395498function make_adaptor (
396499 spl:: HMCSampler , metric:: AbstractMetric , integrator:: AbstractIntegrator
397500)
417520function make_kernel (spl:: HMCSampler , integrator:: AbstractIntegrator )
418521 return spl. κ
419522end
523+
524+ function make_kernel (spl:: SGHMC , integrator:: AbstractIntegrator )
525+ return HMCKernel (Trajectory {EndPointTS} (integrator, FixedNSteps (spl. n_leapfrog)))
526+ end
0 commit comments