4545∂H∂r (h:: Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic} , r:: AbstractVecOrMat ) =
4646 h. metric. M⁻¹ * r
4747
48+ # TODO (kai) make the order of θ and r consistent with neg_energy
49+ # TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic
50+ # The gradient of a position-dependent Hamiltonian system depends on both θ and r.
51+ ∂H∂θ (h:: Hamiltonian , θ:: AbstractVecOrMat , r:: AbstractVecOrMat ) = ∂H∂θ (h, θ)
52+ ∂H∂r (h:: Hamiltonian , θ:: AbstractVecOrMat , r:: AbstractVecOrMat ) = ∂H∂r (h, r)
53+
4854struct PhasePoint{T<: AbstractVecOrMat{<:AbstractFloat} ,V<: DualValue }
4955 θ:: T # Position variables / model parameters.
5056 r:: T # Momentum variables
@@ -156,7 +162,7 @@ phasepoint(
156162 rng:: Union{AbstractRNG,AbstractVector{<:AbstractRNG}} ,
157163 θ:: AbstractVecOrMat{T} ,
158164 h:: Hamiltonian ,
159- ) where {T<: Real } = phasepoint (h, θ, rand (rng, h. metric, h. kinetic))
165+ ) where {T<: Real } = phasepoint (h, θ, rand (rng, h. metric, h. kinetic, θ ))
160166
161167abstract type AbstractMomentumRefreshment end
162168
@@ -168,7 +174,7 @@ refresh(
168174 :: FullMomentumRefreshment ,
169175 h:: Hamiltonian ,
170176 z:: PhasePoint ,
171- ) = phasepoint (h, z. θ, rand (rng, h. metric, h. kinetic))
177+ ) = phasepoint (h, z. θ, rand (rng, h. metric, h. kinetic, z . θ ))
172178
173179"""
174180$(TYPEDEF)
@@ -196,4 +202,8 @@ refresh(
196202 ref:: PartialMomentumRefreshment ,
197203 h:: Hamiltonian ,
198204 z:: PhasePoint ,
199- ) = phasepoint (h, z. θ, ref. α * z. r + sqrt (1 - ref. α^ 2 ) * rand (rng, h. metric, h. kinetic))
205+ ) = phasepoint (
206+ h,
207+ z. θ,
208+ ref. α * z. r + sqrt (1 - ref. α^ 2 ) * rand (rng, h. metric, h. kinetic, z. θ),
209+ )
0 commit comments