|
1 | | -using Random |
2 | | - |
3 | | -### integrator.jl |
4 | | - |
5 | | -import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step |
6 | | -using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size |
7 | | - |
8 | | -""" |
9 | | -$(TYPEDEF) |
10 | | -
|
11 | | -Generalized leapfrog integrator with fixed step size `ϵ`. |
12 | | -
|
13 | | -# Fields |
14 | | -
|
15 | | -$(TYPEDFIELDS) |
16 | | -""" |
17 | | -struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} |
18 | | - "Step size." |
19 | | - ϵ::T |
20 | | - n::Int |
21 | | -end |
22 | | -function Base.show(io::IO, l::GeneralizedLeapfrog) |
23 | | - return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") |
24 | | -end |
25 | | - |
26 | | -# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ |
27 | | -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} |
28 | | - dv = ∂H∂θ(h, θ, r) |
29 | | - return return_cache ? (dv, nothing) : dv |
30 | | -end |
31 | | - |
32 | | -# TODO Make sure vectorization works |
33 | | -# TODO Check if tempering is valid |
34 | | -function step( |
35 | | - lf::GeneralizedLeapfrog{T}, |
36 | | - h::Hamiltonian, |
37 | | - z::P, |
38 | | - n_steps::Int=1; |
39 | | - fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 |
40 | | - full_trajectory::Val{FullTraj}=Val(false), |
41 | | -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} |
42 | | - n_steps = abs(n_steps) # to support `n_steps < 0` cases |
43 | | - |
44 | | - ϵ = fwd ? step_size(lf) : -step_size(lf) |
45 | | - ϵ = ϵ' |
46 | | - |
47 | | - res = if FullTraj |
48 | | - Vector{P}(undef, n_steps) |
49 | | - else |
50 | | - z |
51 | | - end |
52 | | - |
53 | | - for i in 1:n_steps |
54 | | - θ_init, r_init = z.θ, z.r |
55 | | - # Tempering |
56 | | - #r = temper(lf, r, (i=i, is_half=true), n_steps) |
57 | | - #! Eq (16) of Girolami & Calderhead (2011) |
58 | | - r_half = copy(r_init) |
59 | | - local cache |
60 | | - for j in 1:(lf.n) |
61 | | - # Reuse cache for the first iteration |
62 | | - if j == 1 |
63 | | - (; value, gradient) = z.ℓπ |
64 | | - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) |
65 | | - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) |
66 | | - (; value, gradient) = retval |
67 | | - else # reuse cache |
68 | | - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) |
69 | | - end |
70 | | - r_half = r_init - ϵ / 2 * gradient |
71 | | - # println("r_half: ", r_half) |
72 | | - end |
73 | | - #! Eq (17) of Girolami & Calderhead (2011) |
74 | | - θ_full = copy(θ_init) |
75 | | - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop |
76 | | - for j in 1:(lf.n) |
77 | | - θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) |
78 | | - # println("θ_full :", θ_full) |
79 | | - end |
80 | | - #! Eq (18) of Girolami & Calderhead (2011) |
81 | | - (; value, gradient) = ∂H∂θ(h, θ_full, r_half) |
82 | | - r_full = r_half - ϵ / 2 * gradient |
83 | | - # println("r_full: ", r_full) |
84 | | - # Tempering |
85 | | - #r = temper(lf, r, (i=i, is_half=false), n_steps) |
86 | | - # Create a new phase point by caching the logdensity and gradient |
87 | | - z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) |
88 | | - # Update result |
89 | | - if FullTraj |
90 | | - res[i] = z |
91 | | - else |
92 | | - res = z |
93 | | - end |
94 | | - if !isfinite(z) |
95 | | - # Remove undef |
96 | | - if FullTraj |
97 | | - res = res[isassigned.(Ref(res), 1:n_steps)] |
98 | | - end |
99 | | - break |
100 | | - end |
101 | | - # @assert false |
102 | | - end |
103 | | - return res |
104 | | -end |
105 | | - |
106 | | -# TODO Make the order of θ and r consistent with neg_energy |
107 | | -∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) |
108 | | -∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) |
109 | | - |
110 | | -### hamiltonian.jl |
111 | | - |
112 | | -import AdvancedHMC: refresh, phasepoint |
113 | | -using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric |
114 | | - |
115 | | -# To change L180 of hamiltonian.jl |
116 | | -function phasepoint( |
117 | | - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
118 | | - θ::AbstractVecOrMat{T}, |
119 | | - h::Hamiltonian, |
120 | | -) where {T<:Real} |
121 | | - return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) |
122 | | -end |
123 | | - |
124 | | -# To change L191 of hamiltonian.jl |
125 | | -function refresh( |
126 | | - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
127 | | - ::FullMomentumRefreshment, |
128 | | - h::Hamiltonian, |
129 | | - z::PhasePoint, |
130 | | -) |
131 | | - return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) |
132 | | -end |
133 | | - |
134 | | -# To change L215 of hamiltonian.jl |
135 | | -function refresh( |
136 | | - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
137 | | - ref::PartialMomentumRefreshment, |
138 | | - h::Hamiltonian, |
139 | | - z::PhasePoint, |
140 | | -) |
141 | | - return phasepoint( |
142 | | - h, |
143 | | - z.θ, |
144 | | - ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ), |
145 | | - ) |
146 | | -end |
147 | | - |
148 | | -### metric.jl |
149 | | - |
150 | | -import AdvancedHMC: _rand |
151 | | -using AdvancedHMC: AbstractMetric |
152 | | -using LinearAlgebra: eigen, cholesky, Symmetric |
153 | | - |
154 | | -abstract type AbstractRiemannianMetric <: AbstractMetric end |
155 | | - |
156 | | -abstract type AbstractHessianMap end |
157 | | - |
158 | | -struct IdentityMap <: AbstractHessianMap end |
159 | | - |
160 | | -(::IdentityMap)(x) = x |
161 | | - |
162 | | -struct SoftAbsMap{T} <: AbstractHessianMap |
163 | | - α::T |
164 | | -end |
165 | | - |
166 | | -# TODO Register softabs with ReverseDiff |
167 | | -#! The definition of SoftAbs from Page 3 of Betancourt (2012) |
168 | | -function softabs(X, α=20.0) |
169 | | - F = eigen(X) # ReverseDiff cannot diff through `eigen` |
170 | | - Q = hcat(F.vectors) |
171 | | - λ = F.values |
172 | | - softabsλ = λ .* coth.(α * λ) |
173 | | - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ |
174 | | -end |
175 | | - |
176 | | -(map::SoftAbsMap)(x) = softabs(x, map.α)[1] |
177 | | - |
178 | | -struct DenseRiemannianMetric{ |
179 | | - T, |
180 | | - TM<:AbstractHessianMap, |
181 | | - A<:Union{Tuple{Int},Tuple{Int,Int}}, |
182 | | - AV<:AbstractVecOrMat{T}, |
183 | | - TG, |
184 | | - T∂G∂θ, |
185 | | -} <: AbstractRiemannianMetric |
186 | | - size::A |
187 | | - G::TG # TODO store G⁻¹ here instead |
188 | | - ∂G∂θ::T∂G∂θ |
189 | | - map::TM |
190 | | - _temp::AV |
191 | | -end |
192 | | - |
193 | | -# TODO Make dense mass matrix support matrix-mode parallel |
194 | | -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} |
195 | | - _temp = Vector{Float64}(undef, size[1]) |
196 | | - return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) |
197 | | -end |
198 | | -# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) |
199 | | -# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) |
200 | | -# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) |
201 | | -# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) |
202 | | - |
203 | | -# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) |
204 | | - |
205 | | -Base.size(e::DenseRiemannianMetric) = e.size |
206 | | -Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] |
207 | | -Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") |
208 | | - |
209 | | -function rand_momentum( |
210 | | - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
211 | | - metric::DenseRiemannianMetric{T}, |
212 | | - kinetic, |
| 1 | +#! Eq (14) of Girolami & Calderhead (2011) |
| 2 | +function ∂H∂r( |
| 3 | + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, |
213 | 4 | θ::AbstractVecOrMat, |
214 | | -) where {T} |
215 | | - r = _randn(rng, T, size(metric)...) |
216 | | - G⁻¹ = inv(metric.map(metric.G(θ))) |
217 | | - chol = cholesky(Symmetric(G⁻¹)) |
218 | | - ldiv!(chol.U, r) |
219 | | - return r |
220 | | -end |
221 | | - |
222 | | -### hamiltonian.jl |
223 | | - |
224 | | -import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r |
225 | | -using LinearAlgebra: logabsdet, tr |
226 | | - |
227 | | -# QUES Do we want to change everything to position dependent by default? |
228 | | -# Add θ to ∂H∂r for DenseRiemannianMetric |
229 | | -function phasepoint( |
230 | | - h::Hamiltonian{<:DenseRiemannianMetric}, |
231 | | - θ::T, |
232 | | - r::T; |
233 | | - ℓπ=∂H∂θ(h, θ), |
234 | | - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), |
235 | | -) where {T<:AbstractVecOrMat} |
236 | | - return PhasePoint(θ, r, ℓπ, ℓκ) |
237 | | -end |
238 | | - |
239 | | -# Negative kinetic energy |
240 | | -#! Eq (13) of Girolami & Calderhead (2011) |
241 | | -function neg_energy( |
242 | | - h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T |
243 | | -) where {T<:AbstractVecOrMat} |
244 | | - G = h.metric.map(h.metric.G(θ)) |
245 | | - D = size(G, 1) |
246 | | - # Need to consider the normalizing term as it is no longer same for different θs |
247 | | - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined |
248 | | - mul!(h.metric._temp, inv(G), r) |
249 | | - return -logZ - dot(r, h.metric._temp) / 2 |
| 5 | + r::AbstractVecOrMat, |
| 6 | +) |
| 7 | + H = h.metric.G(θ) |
| 8 | + G = h.metric.map(H) |
| 9 | + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't |
250 | 10 | end |
251 | 11 |
|
252 | | -# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) |
253 | 12 | function ∂H∂θ( |
254 | | - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, |
| 13 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, |
255 | 14 | θ::AbstractVecOrMat{T}, |
256 | 15 | r::AbstractVecOrMat{T}, |
257 | 16 | ) where {T} |
@@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} |
293 | 52 | end |
294 | 53 |
|
295 | 54 | function ∂H∂θ( |
296 | | - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, |
| 55 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, |
297 | 56 | θ::AbstractVecOrMat{T}, |
298 | 57 | r::AbstractVecOrMat{T}, |
299 | 58 | ) where {T} |
300 | 59 | return ∂H∂θ_cache(h, θ, r) |
301 | 60 | end |
302 | 61 | function ∂H∂θ_cache( |
303 | | - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, |
| 62 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, |
304 | 63 | θ::AbstractVecOrMat{T}, |
305 | 64 | r::AbstractVecOrMat{T}; |
306 | 65 | return_cache=false, |
@@ -342,17 +101,26 @@ function ∂H∂θ_cache( |
342 | 101 | return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv |
343 | 102 | end |
344 | 103 |
|
345 | | -#! Eq (14) of Girolami & Calderhead (2011) |
346 | | -function ∂H∂r( |
347 | | - h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat |
348 | | -) |
349 | | - H = h.metric.G(θ) |
350 | | - # if !all(isfinite, H) |
351 | | - # println("θ: ", θ) |
352 | | - # println("H: ", H) |
353 | | - # end |
354 | | - G = h.metric.map(H) |
355 | | - # return inv(G) * r |
356 | | - # println("G \ r: ", G \ r) |
357 | | - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't |
| 104 | +# QUES Do we want to change everything to position dependent by default? |
| 105 | +# Add θ to ∂H∂r for DenseRiemannianMetric |
| 106 | +function phasepoint( |
| 107 | + h::Hamiltonian{<:DenseRiemannianMetric}, |
| 108 | + θ::T, |
| 109 | + r::T; |
| 110 | + ℓπ=∂H∂θ(h, θ), |
| 111 | + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), |
| 112 | +) where {T<:AbstractVecOrMat} |
| 113 | + return PhasePoint(θ, r, ℓπ, ℓκ) |
| 114 | +end |
| 115 | + |
| 116 | +#! Eq (13) of Girolami & Calderhead (2011) |
| 117 | +function neg_energy( |
| 118 | + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T |
| 119 | +) where {T<:AbstractVecOrMat} |
| 120 | + G = h.metric.map(h.metric.G(θ)) |
| 121 | + D = size(G, 1) |
| 122 | + # Need to consider the normalizing term as it is no longer same for different θs |
| 123 | + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined |
| 124 | + mul!(h.metric._temp, inv(G), r) |
| 125 | + return -logZ - dot(r, h.metric._temp) / 2 |
358 | 126 | end |
0 commit comments