@@ -36,6 +36,69 @@ function chain_sample_to_varname_dict(
3636 return d
3737end
3838
39+ """
40+ DynamicPPL.varinfos_to_chains(
41+ ::Type{MCMCChains.Chains},
42+ model::Model,
43+ varinfos::AbstractArray{<:AbstractVarInfo},
44+ include_colon_eq::Bool=true,
45+ )
46+
47+ Convert an array of `VarInfo`s to an `MCMCChains.Chains` object.
48+ """
49+ function DynamicPPL. varinfos_to_chains (
50+ :: Type{MCMCChains.Chains} ,
51+ model:: DynamicPPL.Model ,
52+ varinfos:: AbstractMatrix{<:DynamicPPL.AbstractVarInfo} ,
53+ include_colon_eq:: Bool = true ,
54+ )
55+ all_vn_leaves = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
56+ # Re-evaluate model
57+ param_dicts = map (varinfos) do vi
58+ # Dict{VarName, Any}
59+ full_vals = DynamicPPL. values_as_in_model (model, include_colon_eq, vi)
60+ # Separate into individual VarNames.
61+ vn_leaves_and_vals = if isempty (full_vals)
62+ Tuple{VarName,Any}[]
63+ else
64+ iters = map (
65+ AbstractPPL. varname_and_value_leaves,
66+ keys (full_vals),
67+ values (full_vals),
68+ )
69+ mapreduce (collect, vcat, iters)
70+ end
71+ vn_leaves = map (first, vn_leaves_and_vals)
72+ vals = map (last, vn_leaves_and_vals)
73+ for vn_leaf in vn_leaves
74+ push! (all_vn_leaves, vn_leaf)
75+ end
76+ return DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
77+ end
78+ vn_leaves = collect (all_vn_leaves)
79+ vals = [
80+ get (param_dicts[i, j], key, missing ) for i in eachindex (axes (param_dicts, 1 )),
81+ key in vn_leaves, j in eachindex (axes (param_dicts, 2 ))
82+ ]
83+ symbols = map (Symbol, vn_leaves)
84+ info = (
85+ varname_to_symbol= DynamicPPL. OrderedCollections. OrderedDict (
86+ zip (all_vn_leaves, symbols)
87+ ),
88+ )
89+ return MCMCChains. Chains (MCMCChains. concretize (vals), symbols; info= info)
90+ end
91+ function DynamicPPL. varinfos_to_chains (
92+ :: Type{MCMCChains.Chains} ,
93+ model:: DynamicPPL.Model ,
94+ varinfos:: AbstractVector{<:DynamicPPL.AbstractVarInfo} ,
95+ include_colon_eq:: Bool = true ,
96+ )
97+ return DynamicPPL. varinfos_to_chains (
98+ MCMCChains. Chains, model, hcat (varinfos), include_colon_eq
99+ )
100+ end
101+
39102"""
40103 predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41104
0 commit comments