@@ -79,85 +79,18 @@ export InferenceAlgorithm,
7979 predict,
8080 externalsampler
8181
82- # ######################
83- # Sampler abstraction #
84- # ######################
85- abstract type AbstractAdapter end
86- abstract type InferenceAlgorithm end
87- abstract type ParticleInference <: InferenceAlgorithm end
88- abstract type Hamiltonian <: InferenceAlgorithm end
89- abstract type StaticHamiltonian <: Hamiltonian end
90- abstract type AdaptiveHamiltonian <: Hamiltonian end
82+ # ##############################################
83+ # Abstract inferface for inference algorithms #
84+ # ##############################################
9185
92- include (" repeat_sampler.jl" )
93-
94- """
95- ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
96-
97- Represents a sampler that is not an implementation of `InferenceAlgorithm`.
98-
99- The `Unconstrained` type-parameter is to indicate whether the sampler requires unconstrained space.
86+ include (" algorithm.jl" )
10087
101- # Fields
102- $(TYPEDFIELDS)
103- """
104- struct ExternalSampler{S<: AbstractSampler ,AD<: ADTypes.AbstractADType ,Unconstrained} < :
105- InferenceAlgorithm
106- " the sampler to wrap"
107- sampler:: S
108- " the automatic differentiation (AD) backend to use"
109- adtype:: AD
110-
111- """
112- ExternalSampler(sampler::AbstractSampler, adtype::ADTypes.AbstractADType, ::Val{unconstrained})
113-
114- Wrap a sampler so it can be used as an inference algorithm.
115-
116- # Arguments
117- - `sampler::AbstractSampler`: The sampler to wrap.
118- - `adtype::ADTypes.AbstractADType`: The automatic differentiation (AD) backend to use.
119- - `unconstrained::Val=Val{true}()`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
120- """
121- function ExternalSampler (
122- sampler:: AbstractSampler ,
123- adtype:: ADTypes.AbstractADType ,
124- :: Val{unconstrained} = Val (true ),
125- ) where {unconstrained}
126- if ! (unconstrained isa Bool)
127- throw (
128- ArgumentError (" Expected Val{true} or Val{false}, got Val{$unconstrained }" )
129- )
130- end
131- return new {typeof(sampler),typeof(adtype),unconstrained} (sampler, adtype)
132- end
133- end
88+ # ###################
89+ # Sampler wrappers #
90+ # ###################
13491
135- """
136- requires_unconstrained_space(sampler::ExternalSampler)
137-
138- Return `true` if the sampler requires unconstrained space, and `false` otherwise.
139- """
140- requires_unconstrained_space (
141- :: ExternalSampler{<:Any,<:Any,Unconstrained}
142- ) where {Unconstrained} = Unconstrained
143-
144- """
145- externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true)
146-
147- Wrap a sampler so it can be used as an inference algorithm.
148-
149- # Arguments
150- - `sampler::AbstractSampler`: The sampler to wrap.
151-
152- # Keyword Arguments
153- - `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use.
154- - `unconstrained::Bool=true`: Whether the sampler requires unconstrained space.
155- """
156- function externalsampler (
157- sampler:: AbstractSampler ; adtype= Turing. DEFAULT_ADTYPE, unconstrained:: Bool = true
158- )
159- return ExternalSampler (sampler, adtype, Val (unconstrained))
160- end
92+ include (" repeat_sampler.jl" )
93+ include (" external_sampler.jl" )
16194
16295# TODO : make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
16396function DynamicPPL. unflatten (vi:: DynamicPPL.NTVarInfo , θ:: NamedTuple )
@@ -168,30 +101,6 @@ function DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple)
168101 return SimpleVarInfo (θ, vi. logp, vi. transformation)
169102end
170103
171- """
172- Prior()
173-
174- Algorithm for sampling from the prior.
175- """
176- struct Prior <: InferenceAlgorithm end
177-
178- function AbstractMCMC. step (
179- rng:: Random.AbstractRNG ,
180- model:: DynamicPPL.Model ,
181- sampler:: DynamicPPL.Sampler{<:Prior} ,
182- state= nothing ;
183- kwargs... ,
184- )
185- vi = last (
186- DynamicPPL. evaluate!! (
187- model,
188- VarInfo (),
189- SamplingContext (rng, DynamicPPL. SampleFromPrior (), DynamicPPL. PriorContext ()),
190- ),
191- )
192- return vi, nothing
193- end
194-
195104"""
196105 mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)
197106
214123# Default Transition #
215124# #####################
216125# Default
217- # Extended in contrib/inference/abstractmcmc.jl
218126getstats (t) = nothing
219127
220128abstract type AbstractTransition end
@@ -246,95 +154,10 @@ DynamicPPL.getlogp(t::Transition) = t.lp
246154# Metadata of VarInfo object
247155metadata (vi:: AbstractVarInfo ) = (lp= getlogp (vi),)
248156
249- # TODO : Implement additional checks for certain samplers, e.g.
250- # HMC not supporting discrete parameters.
251- function _check_model (model:: DynamicPPL.Model )
252- return DynamicPPL. check_model (model; error_on_failure= true )
253- end
254- function _check_model (model:: DynamicPPL.Model , alg:: InferenceAlgorithm )
255- return _check_model (model)
256- end
257-
258- # ########################################
259- # Default definitions for the interface #
260- # ########################################
261-
262- function AbstractMCMC. sample (
263- model:: AbstractModel , alg:: InferenceAlgorithm , N:: Integer ; kwargs...
264- )
265- return AbstractMCMC. sample (Random. default_rng (), model, alg, N; kwargs... )
266- end
267-
268- function AbstractMCMC. sample (
269- rng:: AbstractRNG ,
270- model:: AbstractModel ,
271- alg:: InferenceAlgorithm ,
272- N:: Integer ;
273- check_model:: Bool = true ,
274- kwargs... ,
275- )
276- check_model && _check_model (model, alg)
277- return AbstractMCMC. sample (rng, model, Sampler (alg), N; kwargs... )
278- end
279-
280- function AbstractMCMC. sample (
281- model:: AbstractModel ,
282- alg:: InferenceAlgorithm ,
283- ensemble:: AbstractMCMC.AbstractMCMCEnsemble ,
284- N:: Integer ,
285- n_chains:: Integer ;
286- kwargs... ,
287- )
288- return AbstractMCMC. sample (
289- Random. default_rng (), model, alg, ensemble, N, n_chains; kwargs...
290- )
291- end
292-
293- function AbstractMCMC. sample (
294- rng:: AbstractRNG ,
295- model:: AbstractModel ,
296- alg:: InferenceAlgorithm ,
297- ensemble:: AbstractMCMC.AbstractMCMCEnsemble ,
298- N:: Integer ,
299- n_chains:: Integer ;
300- check_model:: Bool = true ,
301- kwargs... ,
302- )
303- check_model && _check_model (model, alg)
304- return AbstractMCMC. sample (rng, model, Sampler (alg), ensemble, N, n_chains; kwargs... )
305- end
306-
307- function AbstractMCMC. sample (
308- rng:: AbstractRNG ,
309- model:: AbstractModel ,
310- sampler:: Union{Sampler{<:InferenceAlgorithm},RepeatSampler} ,
311- ensemble:: AbstractMCMC.AbstractMCMCEnsemble ,
312- N:: Integer ,
313- n_chains:: Integer ;
314- chain_type= MCMCChains. Chains,
315- progress= PROGRESS[],
316- kwargs... ,
317- )
318- return AbstractMCMC. mcmcsample (
319- rng,
320- model,
321- sampler,
322- ensemble,
323- N,
324- n_chains;
325- chain_type= chain_type,
326- progress= progress,
327- kwargs... ,
328- )
329- end
330-
331157# #########################
332158# Chain making utilities #
333159# #########################
334160
335- DynamicPPL. default_chain_type (sampler:: Prior ) = MCMCChains. Chains
336- DynamicPPL. default_chain_type (sampler:: Sampler{<:InferenceAlgorithm} ) = MCMCChains. Chains
337-
338161"""
339162 getparams(model, t)
340163
535358# Concrete algorithm implementations. #
536359# ######################################
537360
538- include (" abstractmcmc.jl" )
539361include (" ess.jl" )
540362include (" hmc.jl" )
541363include (" mh.jl" )
@@ -544,6 +366,13 @@ include("particle_mcmc.jl")
544366include (" gibbs.jl" )
545367include (" sghmc.jl" )
546368include (" emcee.jl" )
369+ include (" prior.jl" )
370+
371+ # ###################################
372+ # Generic sample() method dispatch #
373+ # ###################################
374+
375+ include (" sample.jl" )
547376
548377# ###############
549378# Typing tools #
0 commit comments