@@ -50,7 +50,6 @@ import LogDensityProblems
5050import LogDensityProblemsAD
5151import Random
5252import MCMCChains
53- import StatsBase: predict
5453
5554export InferenceAlgorithm,
5655 Hamiltonian,
@@ -78,7 +77,6 @@ export InferenceAlgorithm,
7877 dot_assume,
7978 observe,
8079 dot_observe,
81- predict,
8280 externalsampler
8381
8482# ######################
@@ -396,7 +394,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
396394 # this means that the code below will work both of linked and invlinked `vi`.
397395 # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
398396 # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
399- vals = DynamicPPL. values_as_in_model (model, deepcopy (vi))
397+ vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
400398
401399 # Obtain an iterator over the flattened parameter names and values.
402400 iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
@@ -612,112 +610,6 @@ end
612610DynamicPPL. getspace (spl:: Sampler ) = getspace (spl. alg)
613611DynamicPPL. inspace (vn:: VarName , spl:: Sampler ) = inspace (vn, getspace (spl. alg))
614612
615- """
616-
617- predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
618-
619- Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
620-
621- If `include_all` is `false`, the returned `Chains` will contain only those variables
622- sampled/not present in `chain`.
623-
624- # Details
625- Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
626- and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
627-
628- # Example
629- ```jldoctest
630- julia> using Turing; Turing.setprogress!(false);
631- [ Info: [Turing]: progress logging is disabled globally
632-
633- julia> @model function linear_reg(x, y, σ = 0.1)
634- β ~ Normal(0, 1)
635-
636- for i ∈ eachindex(y)
637- y[i] ~ Normal(β * x[i], σ)
638- end
639- end;
640-
641- julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
642-
643- julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
644-
645- julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
646-
647- julia> m_train = linear_reg(xs_train, ys_train, σ);
648-
649- julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
650- ┌ Info: Found initial step size
651- └ ϵ = 0.003125
652-
653- julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
654-
655- julia> predictions = predict(m_test, chain_lin_reg)
656- Object of type Chains, with data of type 100×2×1 Array{Float64,3}
657-
658- Iterations = 1:100
659- Thinning interval = 1
660- Chains = 1
661- Samples per chain = 100
662- parameters = y[1], y[2]
663-
664- 2-element Array{ChainDataFrame,1}
665-
666- Summary Statistics
667- parameters mean std naive_se mcse ess r_hat
668- ────────── ─────── ────── ──────── ─────── ──────── ──────
669- y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
670- y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
671-
672- Quantiles
673- parameters 2.5% 25.0% 50.0% 75.0% 97.5%
674- ────────── ─────── ─────── ─────── ─────── ───────
675- y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
676- y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
677-
678-
679- julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
680-
681- julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
682- true
683- ```
684- """
685- function predict (model:: Model , chain:: MCMCChains.Chains ; kwargs... )
686- return predict (Random. default_rng (), model, chain; kwargs... )
687- end
688- function predict (
689- rng:: AbstractRNG , model:: Model , chain:: MCMCChains.Chains ; include_all= false
690- )
691- # Don't need all the diagnostics
692- chain_parameters = MCMCChains. get_sections (chain, :parameters )
693-
694- spl = DynamicPPL. SampleFromPrior ()
695-
696- # Sample transitions using `spl` conditioned on values in `chain`
697- transitions = transitions_from_chain (rng, model, chain_parameters; sampler= spl)
698-
699- # Let the Turing internals handle everything else for you
700- chain_result = reduce (
701- MCMCChains. chainscat,
702- [
703- AbstractMCMC. bundle_samples (
704- transitions[:, chain_idx], model, spl, nothing , MCMCChains. Chains
705- ) for chain_idx in 1 : size (transitions, 2 )
706- ],
707- )
708-
709- parameter_names = if include_all
710- names (chain_result, :parameters )
711- else
712- filter (
713- k -> ∉ (k, names (chain_parameters, :parameters )),
714- names (chain_result, :parameters ),
715- )
716- end
717-
718- return chain_result[parameter_names]
719- end
720-
721613"""
722614
723615 transitions_from_chain(
0 commit comments