@@ -30,6 +30,7 @@ import DynamicPPL: get_matching_type,
3030import EllipticalSliceSampling
3131import Random
3232import MCMCChains
33+ import StatsBase: predict
3334
3435export InferenceAlgorithm,
3536 Hamiltonian,
@@ -40,6 +41,7 @@ export InferenceAlgorithm,
4041 SampleFromPrior,
4142 MH,
4243 ESS,
44+ Emcee,
4345 Gibbs, # classic sampling
4446 HMC,
4547 SGLD,
@@ -56,7 +58,9 @@ export InferenceAlgorithm,
5658 dot_assume,
5759 observe,
5860 dot_observe,
59- resume
61+ resume,
62+ predict,
63+ isgibbscomponent
6064
6165# ######################
6266# Sampler abstraction #
@@ -135,7 +139,7 @@ const TURING_INTERNAL_VARS = (internals = [
135139 " step_size" ,
136140 " nom_step_size" ,
137141 " tree_depth" ,
138- " is_adapt" ,
142+ " is_adapt"
139143],)
140144
141145# ########################################
@@ -305,19 +309,21 @@ Return a named tuple of parameters.
305309getparams (t) = t. θ
306310getparams (t:: VarInfo ) = tonamedtuple (TypedVarInfo (t))
307311
308- function _params_to_array (ts)
309- names_set = Set {String } ()
312+ function _params_to_array (ts:: Vector )
313+ names = Vector {Symbol } ()
310314 # Extract the parameter names and values from each transition.
311315 dicts = map (ts) do t
312316 nms, vs = flatten_namedtuple (getparams (t))
313317 for nm in nms
314- push! (names_set, nm)
318+ if ! (nm in names)
319+ push! (names, nm)
320+ end
315321 end
316322 # Convert the names and values to a single dictionary.
317323 return Dict (nms[j] => vs[j] for j in 1 : length (vs))
318324 end
319- names = collect (names_set)
320- vals = [get (dicts[i], key, missing ) for i in eachindex (dicts),
325+ # names = collect(names_set)
326+ vals = [get (dicts[i], key, missing ) for i in eachindex (dicts),
321327 (j, key) in enumerate (names)]
322328
323329 return names, vals
@@ -327,7 +333,7 @@ function flatten_namedtuple(nt::NamedTuple)
327333 names_vals = mapreduce (vcat, keys (nt)) do k
328334 v = nt[k]
329335 if length (v) == 1
330- return [(string (k), v)]
336+ return [(Symbol (k), v)]
331337 else
332338 return mapreduce (vcat, zip (v[1 ], v[2 ])) do (vnval, vn)
333339 return collect (FlattenIterator (vn, vnval))
339345
340346function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
341347 valmat = reshape ([getlogp (t) for t in ts], :, 1 )
342- return [" lp " ], valmat
348+ return [:lp ], valmat
343349end
344350
345351function get_transition_extras (ts:: AbstractVector )
@@ -353,7 +359,7 @@ function get_transition_extras(ts::AbstractVector)
353359
354360 # Iterate through each transition.
355361 for t in ts
356- extra_names = String []
362+ extra_names = Symbol []
357363 vals = []
358364
359365 # Iterate through each of the additional field names
@@ -365,11 +371,11 @@ function get_transition_extras(ts::AbstractVector)
365371 prop = getproperty (t, p)
366372 if prop isa NamedTuple
367373 for (k, v) in pairs (prop)
368- push! (extra_names, string (k))
374+ push! (extra_names, Symbol (k))
369375 push! (vals, v)
370376 end
371377 else
372- push! (extra_names, string (p))
378+ push! (extra_names, Symbol (p))
373379 push! (vals, prop)
374380 end
375381 end
@@ -432,12 +438,11 @@ function AbstractMCMC.bundle_samples(
432438 # Chain construction.
433439 return MCMCChains. Chains (
434440 parray,
435- string .( nms) ,
441+ nms,
436442 deepcopy (TURING_INTERNAL_VARS);
437443 evidence= le,
438444 info= info,
439- sorted= true
440- )
445+ ) |> sort
441446end
442447
443448# This is type piracy (for SampleFromPrior).
@@ -535,12 +540,13 @@ include("is.jl")
535540include (" AdvancedSMC.jl" )
536541include (" gibbs.jl" )
537542include (" ../contrib/inference/sghmc.jl" )
543+ include (" emcee.jl" )
538544
539545# ###############
540546# Typing tools #
541547# ###############
542548
543- for alg in (:SMC , :PG , :MH , :IS , :ESS , :Gibbs )
549+ for alg in (:SMC , :PG , :MH , :IS , :ESS , :Gibbs , :Emcee )
544550 @eval DynamicPPL. getspace (:: $alg{space} ) where {space} = space
545551end
546552for alg in (:HMC , :HMCDA , :NUTS , :SGLD , :SGHMC )
@@ -571,13 +577,12 @@ function get_matching_type(
571577)
572578 return floatof (eltype (vi, spl))
573579end
574- function get_matching_type (
575- spl:: AbstractSampler ,
576- vi,
577- :: Type{TV} ,
578- ) where {T, N, TV <: Array{T, N} }
580+ function get_matching_type (spl:: AbstractSampler , vi, :: Type{<:Array{T,N}} ) where {T,N}
579581 return Array{get_matching_type (spl, vi, T), N}
580582end
583+ function get_matching_type (spl:: AbstractSampler , vi, :: Type{<:Array{T}} ) where T
584+ return Array{get_matching_type (spl, vi, T)}
585+ end
581586function get_matching_type (
582587 spl:: Sampler{<:Union{PG, SMC}} ,
583588 vi,
@@ -593,4 +598,182 @@ end
593598DynamicPPL. getspace (spl:: Sampler ) = getspace (spl. alg)
594599DynamicPPL. inspace (vn:: VarName , spl:: Sampler ) = inspace (vn, getspace (spl. alg))
595600
601+ """
602+
603+ predict(model::Model, chain::MCMCChains.Chains; include_all=false)
604+
605+ Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
606+
607+ If `include_all` is `false`, the returned `Chains` will contain only those variables
608+ sampled/not present in `chain`.
609+
610+ # Details
611+ Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
612+ and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
613+
614+ # Example
615+ ```jldoctest
616+ julia> using Turing; Turing.turnprogress(false);
617+ [ Info: [Turing]: progress logging is disabled globally
618+
619+ julia> @model function linear_reg(x, y, σ = 0.1)
620+ β ~ Normal(0, 1)
621+
622+ for i ∈ eachindex(y)
623+ y[i] ~ Normal(β * x[i], σ)
624+ end
625+ end;
626+
627+ julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
628+
629+ julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
630+
631+ julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
632+
633+ julia> m_train = linear_reg(xs_train, ys_train, σ);
634+
635+ julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
636+ ┌ Info: Found initial step size
637+ └ ϵ = 0.003125
638+
639+ julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
640+
641+ julia> predictions = Turing.Inference.predict(m_test, chain_lin_reg)
642+ Object of type Chains, with data of type 100×2×1 Array{Float64,3}
643+
644+ Iterations = 1:100
645+ Thinning interval = 1
646+ Chains = 1
647+ Samples per chain = 100
648+ parameters = y[1], y[2]
649+
650+ 2-element Array{ChainDataFrame,1}
651+
652+ Summary Statistics
653+ parameters mean std naive_se mcse ess r_hat
654+ ────────── ─────── ────── ──────── ─────── ──────── ──────
655+ y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
656+ y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
657+
658+ Quantiles
659+ parameters 2.5% 25.0% 50.0% 75.0% 97.5%
660+ ────────── ─────── ─────── ─────── ─────── ───────
661+ y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
662+ y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
663+
664+
665+ julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
666+
667+ julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
668+ true
669+ ```
670+ """
671+ function predict (model:: Turing.Model , chain:: MCMCChains.Chains ; include_all = false )
672+ spl = DynamicPPL. SampleFromPrior ()
673+
674+ # Sample transitions using `spl` conditioned on values in `chain`
675+ transitions = transitions_from_chain (model, chain; sampler = spl)
676+
677+ # Let the Turing internals handle everything else for you
678+ chain_result = AbstractMCMC. bundle_samples (
679+ Distributions. GLOBAL_RNG,
680+ model,
681+ spl,
682+ length (chain),
683+ transitions,
684+ MCMCChains. Chains
685+ )
686+
687+ parameter_names = if include_all
688+ names (chain_result, :parameters )
689+ else
690+ filter (k -> ∉ (k, names (chain, :parameters )), names (chain_result, :parameters ))
691+ end
692+
693+ return chain_result[parameter_names]
694+ end
695+
696+ """
697+
698+ transitions_from_chain(
699+ model::Model,
700+ chain::MCMCChains.Chains;
701+ sampler = DynamicPPL.SampleFromPrior()
702+ )
703+
704+ Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
705+
706+ The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
707+
708+ # Details
709+
710+ In a bit more detail, the process is as follows:
711+ 1. For every `sample` in `chain`
712+ 1. For every `variable` in `sample`
713+ 1. Set `variable` in `model` to its value in `sample`
714+ 2. Execute `model` with variables fixed as above, sampling variables NOT present
715+ in `chain` using `SampleFromPrior`
716+ 3. Return sampled variables and log-joint
717+
718+ # Example
719+ ```julia-repl
720+ julia> using Turing
721+
722+ julia> @model function demo()
723+ m ~ Normal(0, 1)
724+ x ~ Normal(m, 1)
725+ end;
726+
727+ julia> m = demo();
728+
729+ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
730+
731+ julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
732+
733+ julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
734+ 2-element Array{Float64,1}:
735+ -3.6294991938628374
736+ -2.5697948166987845
737+
738+ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
739+ 2-element Array{Array{Float64,1},1}:
740+ [-2.0844148956440796]
741+ [-1.704630494695469]
742+ ```
743+ """
744+ function transitions_from_chain (
745+ model:: Turing.Model ,
746+ chain:: MCMCChains.Chains ;
747+ sampler = DynamicPPL. SampleFromPrior ()
748+ )
749+ vi = Turing. VarInfo (model)
750+
751+ transitions = map (1 : length (chain)) do i
752+ c = chain[i]
753+ md = vi. metadata
754+ for v in keys (md)
755+ for vn in md[v]. vns
756+ vn_symbol = Symbol (vn)
757+ if vn_symbol ∈ c. name_map. parameters
758+ val = c[vn_symbol]
759+ DynamicPPL. setval! (vi, val, vn)
760+ DynamicPPL. settrans! (vi, false , vn)
761+ else
762+ # delete so we can sample from prior
763+ DynamicPPL. set_flag! (vi, vn, " del" )
764+ end
765+ end
766+ end
767+ # Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
768+ model (vi, sampler)
769+
770+ # Convert `VarInfo` into `NamedTuple` and save
771+ theta = DynamicPPL. tonamedtuple (vi)
772+ lp = Turing. getlogp (vi)
773+ Transition (theta, lp)
774+ end
775+
776+ return transitions
777+ end
778+
596779end # module
0 commit comments