Skip to content

Commit bac9477

Browse files
authored
Update to AbstractMCMC 2 and fully commit to its interface (#8)
1 parent a2ec96c commit bac9477

File tree

8 files changed

+32
-81
lines changed

8 files changed

+32
-81
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

1313
[compat]
14-
AbstractMCMC = "0.5, 1"
14+
AbstractMCMC = "2"
1515
ArrayInterface = "2"
1616
Distributions = "0.22, 0.23"
1717
julia = "1"

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ priors with non-zero means and handle the change of variables internally.
2222

2323
## Usage
2424

25-
Probably most users would like to use the exported function
25+
Probably most users would like to generate a Markov chain of samples from
26+
the posterior distribution by calling
2627
```julia
27-
ESS_mcmc([rng, ]prior, loglikelihood, N[; kwargs...])
28+
sample([rng, ]ESSModel(prior, loglikelihood), ESS(), N[; kwargs...])
2829
```
2930
which returns a vector of `N` samples for approximating the posterior of
3031
a model with a Gaussian prior that allows sampling from the `prior` and
@@ -34,10 +35,10 @@ If you want to have more control about the sampling procedure (e.g., if you
3435
only want to save a subset of samples or want to use another stopping
3536
criterion), the function
3637
```julia
37-
AbstractMCMC.steps!(
38+
AbstractMCMC.steps(
3839
[rng,]
39-
EllipticalSliceSampling.Model(prior, loglikelihood),
40-
EllipticalSliceSampling.EllipticalSliceSampler();
40+
ESSModel(prior, loglikelihood),
41+
ESS();
4142
kwargs...
4243
)
4344
```

src/EllipticalSliceSampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import Distributions
77
import Random
88
import Statistics
99

10-
export ESS_mcmc
10+
export sample, ESSModel, ESS
1111

1212
include("abstractmcmc.jl")
1313
include("model.jl")

src/abstractmcmc.jl

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
# elliptical slice sampler
2-
struct EllipticalSliceSampler <: AbstractMCMC.AbstractSampler end
2+
struct ESS <: AbstractMCMC.AbstractSampler end
33

44
# state of the elliptical slice sampler
5-
struct EllipticalSliceSamplerState{S,L}
5+
struct ESSState{S,L}
66
"Sample of the elliptical slice sampler."
77
sample::S
88
"Log-likelihood of the sample."
99
loglikelihood::L
1010
end
1111

1212
# first step of the elliptical slice sampler
13-
function AbstractMCMC.step!(
13+
function AbstractMCMC.step(
1414
rng::Random.AbstractRNG,
1515
model::AbstractMCMC.AbstractModel,
16-
::EllipticalSliceSampler,
17-
N::Integer,
18-
::Nothing;
16+
::ESS;
1917
kwargs...
2018
)
2119
# initial sample from the Gaussian prior
@@ -24,16 +22,15 @@ function AbstractMCMC.step!(
2422
# compute log-likelihood of the initial sample
2523
loglikelihood = Distributions.loglikelihood(model, f)
2624

27-
return EllipticalSliceSamplerState(f, loglikelihood)
25+
return f, ESSState(f, loglikelihood)
2826
end
2927

3028
# subsequent steps of the elliptical slice sampler
31-
function AbstractMCMC.step!(
29+
function AbstractMCMC.step(
3230
rng::Random.AbstractRNG,
3331
model::AbstractMCMC.AbstractModel,
34-
::EllipticalSliceSampler,
35-
N::Integer,
36-
state::EllipticalSliceSamplerState;
32+
::ESS,
33+
state::ESSState;
3734
kwargs...
3835
)
3936
# sample from Gaussian prior
@@ -78,29 +75,5 @@ function AbstractMCMC.step!(
7875
loglikelihood = Distributions.loglikelihood(model, fnext)
7976
end
8077

81-
return EllipticalSliceSamplerState(fnext, loglikelihood)
82-
end
83-
84-
# only save the samples by default
85-
function AbstractMCMC.transitions_init(
86-
state::EllipticalSliceSamplerState,
87-
model::AbstractMCMC.AbstractModel,
88-
::EllipticalSliceSampler,
89-
N::Integer;
90-
kwargs...
91-
)
92-
return Vector{typeof(state.sample)}(undef, N)
93-
end
94-
95-
function AbstractMCMC.transitions_save!(
96-
samples::AbstractVector{S},
97-
iteration::Integer,
98-
state::EllipticalSliceSamplerState{S},
99-
model::AbstractMCMC.AbstractModel,
100-
::EllipticalSliceSampler,
101-
N::Integer;
102-
kwargs...
103-
) where S
104-
samples[iteration] = state.sample
105-
return
78+
return fnext, ESSState(fnext, loglikelihood)
10679
end

src/interface.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,3 @@
1-
# public interface
2-
3-
"""
4-
ESS_mcmc([rng, ]prior, loglikelihood, N; kwargs...)
5-
6-
Create a Markov chain of `N` samples for a model with given `prior` and `loglikelihood`
7-
functions using the elliptical slice sampling algorithm.
8-
"""
9-
function ESS_mcmc(
10-
rng::Random.AbstractRNG,
11-
prior,
12-
loglikelihood,
13-
N::Integer;
14-
kwargs...
15-
)
16-
model = Model(prior, loglikelihood)
17-
return AbstractMCMC.sample(rng, model, EllipticalSliceSampler(), N; kwargs...)
18-
end
19-
20-
function ESS_mcmc(prior, loglikelihood, N::Integer; kwargs...)
21-
return ESS_mcmc(Random.GLOBAL_RNG, prior, loglikelihood, N; kwargs...)
22-
end
23-
241
# private interface
252

263
"""

src/model.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# internal model structure consisting of prior, log-likelihood function, and a cache
22

3-
struct Model{P,L,C} <: AbstractMCMC.AbstractModel
3+
struct ESSModel{P,L,C} <: AbstractMCMC.AbstractModel
44
"Gaussian prior."
55
prior::P
66
"Log likelihood function."
77
loglikelihood::L
88
"Cache."
99
cache::C
1010

11-
function Model{P,L}(prior::P, loglikelihood::L) where {P,L}
11+
function ESSModel{P,L}(prior::P, loglikelihood::L) where {P,L}
1212
isgaussian(P) ||
1313
error("prior distribution has to be a Gaussian distribution")
1414

@@ -19,8 +19,8 @@ struct Model{P,L,C} <: AbstractMCMC.AbstractModel
1919
end
2020
end
2121

22-
Model(prior, loglikelihood) =
23-
Model{typeof(prior),typeof(loglikelihood)}(prior, loglikelihood)
22+
ESSModel(prior, loglikelihood) =
23+
ESSModel{typeof(prior),typeof(loglikelihood)}(prior, loglikelihood)
2424

2525
# cache for high-dimensional samplers
2626
function cache(dist)
@@ -39,11 +39,11 @@ isgaussian(dist) = false
3939
randtype(dist) = eltype(dist)
4040

4141
# evaluate the loglikelihood of a sample
42-
Distributions.loglikelihood(model::Model, f) = model.loglikelihood(f)
42+
Distributions.loglikelihood(model::ESSModel, f) = model.loglikelihood(f)
4343

4444
# sample from the prior
45-
initial_sample(rng::Random.AbstractRNG, model::Model) = rand(rng, model.prior)
46-
function sample_prior(rng::Random.AbstractRNG, model::Model)
45+
initial_sample(rng::Random.AbstractRNG, model::ESSModel) = rand(rng, model.prior)
46+
function sample_prior(rng::Random.AbstractRNG, model::ESSModel)
4747
cache = model.cache
4848

4949
if cache === nothing
@@ -55,8 +55,8 @@ function sample_prior(rng::Random.AbstractRNG, model::Model)
5555
end
5656

5757
# compute the proposal
58-
proposal(model::Model, f, ν, θ) = proposal(model.prior, f, ν, θ)
59-
proposal!(out, model::Model, f, ν, θ) = proposal!(out, model.prior, f, ν, θ)
58+
proposal(model::ESSModel, f, ν, θ) = proposal(model.prior, f, ν, θ)
59+
proposal!(out, model::ESSModel, f, ν, θ) = proposal!(out, model.prior, f, ν, θ)
6060

6161
# default out-of-place implementation
6262
function proposal(prior, f, ν, θ)

test/regression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ const σ = 0.3
4141
end
4242

4343
# run elliptical slice sampler for 100 000 time steps
44-
samples = ESS_mcmc(prior, ℓ, 100_000; progress = false)
44+
samples = sample(ESSModel(prior, ℓ), ESS(), 100_000; progress = false)
4545

4646
# compute analytical posterior of GP
4747
posterior_Σ = prior_Σ * (I - (prior_Σ + σ^2 * I) \ prior_Σ)
@@ -66,7 +66,7 @@ end
6666
end
6767

6868
# run elliptical slice sampling for 100 000 time steps
69-
samples = ESS_mcmc(prior, ℓ, 100_000; progress = false)
69+
samples = sample(ESSModel(prior, ℓ), ESS(), 100_000; progress = false)
7070

7171
# compute analytical posterior
7272
posterior_μ = observations / (1 + σ^2)

test/simple.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using Test
1818
μ = 0.8
1919
σ² = 0.2
2020

21-
samples = ESS_mcmc(prior, ℓ, 2_000; progress = false)
21+
samples = sample(ESSModel(prior, ℓ), ESS(), 2_000; progress = false)
2222

2323
@test mean(samples) μ atol=0.05
2424
@test var(samples) σ² atol=0.05
@@ -35,7 +35,7 @@ end
3535
μ = 0.9
3636
σ² = 0.2
3737

38-
samples = ESS_mcmc(prior, ℓ, 2_000; progress = false)
38+
samples = sample(ESSModel(prior, ℓ), ESS(), 2_000; progress = false)
3939

4040
@test mean(samples) μ atol=0.05
4141
@test var(samples) σ² atol=0.05
@@ -53,7 +53,7 @@ end
5353
μ = [0.8]
5454
σ² = [0.2]
5555

56-
samples = ESS_mcmc(prior, ℓ, 2_000; progress = false)
56+
samples = sample(ESSModel(prior, ℓ), ESS(), 2_000; progress = false)
5757

5858
@test mean(samples) μ atol=0.05
5959
@test var(samples) σ² atol=0.05
@@ -70,7 +70,7 @@ end
7070
μ = [0.9]
7171
σ² = [0.2]
7272

73-
samples = ESS_mcmc(prior, ℓ, 2_000; progress = false)
73+
samples = sample(ESSModel(prior, ℓ), ESS(), 2_000; progress = false)
7474

7575
@test mean(samples) μ atol=0.05
7676
@test var(samples) σ² atol=0.05

0 commit comments

Comments
 (0)