Skip to content

Commit 3bbb17f

Browse files
committed
feat(wip): port reemannian hmc from research repo
Signed-off-by: Kai Xu <xuk@ibm.com>
1 parent 6b023f7 commit 3bbb17f

File tree

4 files changed

+523
-0
lines changed

4 files changed

+523
-0
lines changed

src/relativistic/hamiltonian.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
abstract type AbstractRelativisticKinetic{T} <: AbstractKinetic end
2+
3+
struct RelativisticKinetic{T} <: AbstractRelativisticKinetic{T}
4+
"Mass"
5+
m::T
6+
"Speed of light"
7+
c::T
8+
end
9+
10+
relativistic_mass(kinetic::RelativisticKinetic, r, r′ = r) =
11+
kinetic.m * sqrt(dot(r, r′) / (kinetic.m ^ 2 * kinetic.c ^ 2) + 1)
12+
relativistic_energy(kinetic::RelativisticKinetic, r, r′ = r) = sum(
13+
kinetic.c ^ 2 * relativistic_mass(kinetic, r, r′)
14+
)
15+
16+
struct DimensionwiseRelativisticKinetic{T} <: AbstractRelativisticKinetic{T}
17+
"Mass"
18+
m::T
19+
"Speed of light"
20+
c::T
21+
end
22+
23+
relativistic_mass(kinetic::DimensionwiseRelativisticKinetic, r, r′ = r) =
24+
kinetic.m .* sqrt.(r .* r′ ./ (kinetic.m .^ 2 .* kinetic.c .^ 2) .+ 1)
25+
relativistic_energy(kinetic::DimensionwiseRelativisticKinetic, r, r′ = r) = sum(
26+
kinetic.c .^ 2 .* relativistic_mass(kinetic, r, r′)
27+
)
28+
29+
function ∂H∂r(
30+
h::Hamiltonian{<:UnitEuclideanMetric,<:AbstractRelativisticKinetic},
31+
r::AbstractVecOrMat,
32+
)
33+
mass = relativistic_mass(h.kinetic, r)
34+
return r ./ mass
35+
end
36+
function ∂H∂r(
37+
h::Hamiltonian{<:DiagEuclideanMetric,<:AbstractRelativisticKinetic},
38+
r::AbstractVecOrMat,
39+
)
40+
r = h.metric.sqrtM⁻¹ .* r
41+
mass = relativistic_mass(h.kinetic, r)
42+
red_term = r ./ mass # red part of (15)
43+
return h.metric.sqrtM⁻¹ .* red_term # (15)
44+
end
45+
function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric, <:AbstractRelativisticKinetic}, r::AbstractVecOrMat)
46+
r = h.metric.cholM⁻¹ * r
47+
mass = relativistic_mass(h.kinetic, r)
48+
red_term = r ./ mass
49+
return h.metric.cholM⁻¹' * red_term
50+
end
51+
52+
function neg_energy(
53+
h::Hamiltonian{<:UnitEuclideanMetric,<:AbstractRelativisticKinetic},
54+
r::T,
55+
θ::T,
56+
) where {T<:AbstractVector}
57+
return -relativistic_energy(h.kinetic, r)
58+
end
59+
function neg_energy(
60+
h::Hamiltonian{<:DiagEuclideanMetric,<:AbstractRelativisticKinetic},
61+
r::T,
62+
θ::T,
63+
) where {T<:AbstractVector}
64+
r = h.metric.sqrtM⁻¹ .* r
65+
return -relativistic_energy(h.kinetic, r)
66+
end
67+
function neg_energy(
68+
h::Hamiltonian{<:DenseEuclideanMetric,<:AbstractRelativisticKinetic},
69+
r::T,
70+
θ::T
71+
) where {T<:AbstractVector}
72+
r = h.metric.cholM⁻¹ * r
73+
return -relativistic_energy(h.kinetic, r)
74+
end

src/relativistic/metric.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
function rand_angles(rng::AbstractRNG, dim)
2+
rand(rng, dim - 1) .* vcat(fill(π, dim - 2), 2 * π)
3+
end
4+
5+
"Special case of `polar2spherical` with dimension equal 2"
6+
polar2cartesian(θ, d) = d * [cos(θ), sin(θ)]
7+
8+
# ref: https://en.wikipedia.org/wiki/N-sphere#Spherical_coordinates
9+
function polar2spherical(θs, d)
10+
cos_lst, sin_lst = cos.(θs), sin.(θs)
11+
suffixed_cos_lst = vcat(cos_lst, 1) # [cos(θ[1]), cos(θ[2]), ..., cos(θ[d-1]), 1]
12+
prefixed_cumprod_sin_lst = vcat(1, cumprod(sin_lst)) # [1, sin(θ[1]), sin(θ[1]) * sin(θ[2]), ..., sin(θ[1]) * ... * sin(θ[d-1])]
13+
return d * prefixed_cumprod_sin_lst .* suffixed_cos_lst
14+
end
15+
16+
momentum_mode(m, c) = sqrt((1 / c^2 + sqrt(1 / c^2 + 4 * m^2)) / 2) # mode of the momentum distribution
17+
18+
function _rand(
19+
rng::AbstractRNG,
20+
metric::UnitEuclideanMetric{T},
21+
kinetic::RelativisticKinetic{T},
22+
) where {T}
23+
densityfunc = x -> exp(-relativistic_energy(kinetic, [x])) * x
24+
mm = momentum_mode(kinetic.m, kinetic.c)
25+
sampler = RejectionSampler(densityfunc, (0.0, Inf), (mm / 2, mm * 2); max_segments = 5)
26+
sz = size(metric)
27+
θs = rand_angles(rng, prod(sz))
28+
d = only(run_sampler!(rng, sampler, 1))
29+
r = polar2spherical(θs, d * rand(rng, [-1, +1])) # TODO Double check if +/- is needed
30+
r = reshape(r, sz)
31+
return r
32+
end
33+
34+
# TODO Support AbstractVector{<:AbstractRNG}
35+
# FIXME Unit-test this using slice sampler or HMC sampler
36+
function _rand(
37+
rng::AbstractRNG,
38+
metric::UnitEuclideanMetric{T},
39+
kinetic::DimensionwiseRelativisticKinetic{T},
40+
) where {T}
41+
h_temp = Hamiltonian(metric, kinetic, identity, identity)
42+
densityfunc = x -> exp(neg_energy(h_temp, [x], [x]))
43+
sampler = RejectionSampler(densityfunc, (-Inf, Inf); max_segments = 5)
44+
sz = size(metric)
45+
r = run_sampler!(rng, sampler, prod(sz))
46+
r = reshape(r, sz)
47+
return r
48+
end
49+
50+
# TODO Support AbstractVector{<:AbstractRNG}
51+
function _rand(
52+
rng::AbstractRNG,
53+
metric::DiagEuclideanMetric{T},
54+
kinetic::AbstractRelativisticKinetic{T},
55+
) where {T}
56+
r = _rand(rng, UnitEuclideanMetric(size(metric)), kinetic)
57+
# p' = A p where A = sqrtM
58+
r ./= metric.sqrtM⁻¹
59+
return r
60+
end
61+
# TODO Support AbstractVector{<:AbstractRNG}
62+
function _rand(
63+
rng::AbstractRNG,
64+
metric::DenseEuclideanMetric{T},
65+
kinetic::AbstractRelativisticKinetic{T},
66+
) where {T}
67+
r = _rand(rng, UnitEuclideanMetric(size(metric)), kinetic)
68+
# p' = A p where A = cholM
69+
ldiv!(metric.cholM⁻¹, r)
70+
return r
71+
end

0 commit comments

Comments
 (0)