Skip to content

Commit 7fc04d2

Browse files
committed
feat: port relativistic hmc from research repo
Signed-off-by: Kai Xu <xuk@ibm.com>
1 parent 2b3814c commit 7fc04d2

File tree

3 files changed

+25
-87
lines changed

3 files changed

+25
-87
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.6.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
7+
AdaptiveRejectionSampling = "c75e803d-635f-53bd-ab7d-544e482d8c75"
78
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
@@ -31,23 +32,24 @@ AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3132

3233
[compat]
3334
AbstractMCMC = "4.2, 5"
35+
AdaptiveRejectionSampling = "0.1.1"
3436
ArgCheck = "1, 2"
3537
CUDA = "3, 4, 5"
3638
DocStringExtensions = "0.8, 0.9"
3739
InplaceOps = "0.3"
40+
LinearAlgebra = "1.6"
3841
LogDensityProblems = "2"
3942
LogDensityProblemsAD = "1"
4043
MCMCChains = "5, 6"
4144
OrdinaryDiffEq = "6"
4245
ProgressMeter = "1"
46+
Random = "1.6"
4347
Requires = "0.5, 1"
4448
Setfield = "0.7, 0.8, 1"
4549
SimpleUnPack = "1.1"
4650
Statistics = "1.6"
4751
StatsBase = "0.31, 0.32, 0.33, 0.34"
4852
StatsFuns = "0.8, 0.9, 1"
49-
LinearAlgebra = "1.6"
50-
Random = "1.6"
5153
julia = "1.6"
5254

5355
[extras]

research/src/relativistic_hmc.jl

Lines changed: 0 additions & 83 deletions
This file was deleted.

src/AdvancedHMC.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ export Hamiltonian
5252

5353
include("integrator.jl")
5454
export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
55-
include("riemannian/integrator.jl")
56-
export GeneralizedLeapfrog
5755

5856
include("trajectory.jl")
5957
export Trajectory,
@@ -128,6 +126,27 @@ export sample
128126
include("constructors.jl")
129127
export HMCSampler, HMC, NUTS, HMCDA
130128

129+
module Experimental
130+
using Random, Statistics, LinearAlgebra
131+
using ..AdvancedHMC
132+
133+
import ..AdvancedHMC: ∂H∂r, neg_energy, AbstractKinetic
134+
import Random: AbstractRNG
135+
include("relativistic/hamiltonian.jl")
136+
export RelativisticKinetic, DimensionwiseRelativisticKinetic
137+
138+
using AdaptiveRejectionSampling: RejectionSampler, run_sampler!
139+
import ..AdvancedHMC: _rand
140+
include("relativistic/metric.jl")
141+
142+
using ..AdvancedHMC: @unpack, TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step, step_size
143+
import ..AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step
144+
include("riemannian/integrator.jl")
145+
include("riemannian/hamiltonian.jl")
146+
include("riemannian/metric.jl")
147+
export GeneralizedLeapfrog
148+
end
149+
131150
include("abstractmcmc.jl")
132151

133152
## Without explicit AD backend

0 commit comments

Comments
 (0)