Skip to content

Commit 268b106

Browse files
committed
Implement varinfos_to_chains function
1 parent 1b159a6 commit 268b106

File tree

7 files changed

+127
-29
lines changed

7 files changed

+127
-29
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.2
4+
5+
Added a new exported function, `DynamicPPL.varinfos_to_chains`, which automatically converts a collection of VarInfos to a given Chains type.
6+
37
## 0.38.1
48

59
Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.1"
3+
version = "0.38.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,69 @@ function chain_sample_to_varname_dict(
3636
return d
3737
end
3838

39+
"""
40+
DynamicPPL.varinfos_to_chains(
41+
::Type{MCMCChains.Chains},
42+
model::Model,
43+
varinfos::AbstractArray{<:AbstractVarInfo},
44+
include_colon_eq::Bool=true,
45+
)
46+
47+
Convert an array of `VarInfo`s to an `MCMCChains.Chains` object.
48+
"""
49+
function DynamicPPL.varinfos_to_chains(
50+
::Type{MCMCChains.Chains},
51+
model::DynamicPPL.Model,
52+
varinfos::AbstractMatrix{<:DynamicPPL.AbstractVarInfo},
53+
include_colon_eq::Bool=true,
54+
)
55+
all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
56+
# Re-evaluate model
57+
param_dicts = map(varinfos) do vi
58+
# Dict{VarName, Any}
59+
full_vals = DynamicPPL.values_as_in_model(model, include_colon_eq, vi)
60+
# Separate into individual VarNames.
61+
vn_leaves_and_vals = if isempty(full_vals)
62+
Tuple{VarName,Any}[]
63+
else
64+
iters = map(
65+
AbstractPPL.varname_and_value_leaves,
66+
keys(full_vals),
67+
values(full_vals),
68+
)
69+
mapreduce(collect, vcat, iters)
70+
end
71+
vn_leaves = map(first, vn_leaves_and_vals)
72+
vals = map(last, vn_leaves_and_vals)
73+
for vn_leaf in vn_leaves
74+
push!(all_vn_leaves, vn_leaf)
75+
end
76+
return DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
77+
end
78+
vn_leaves = collect(all_vn_leaves)
79+
vals = [
80+
get(param_dicts[i, j], key, missing) for i in eachindex(axes(param_dicts, 1)),
81+
key in vn_leaves, j in eachindex(axes(param_dicts, 2))
82+
]
83+
symbols = map(Symbol, vn_leaves)
84+
info = (
85+
varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
86+
zip(all_vn_leaves, symbols)
87+
),
88+
)
89+
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols; info=info)
90+
end
91+
function DynamicPPL.varinfos_to_chains(
92+
::Type{MCMCChains.Chains},
93+
model::DynamicPPL.Model,
94+
varinfos::AbstractVector{<:DynamicPPL.AbstractVarInfo},
95+
include_colon_eq::Bool=true,
96+
)
97+
return DynamicPPL.varinfos_to_chains(
98+
MCMCChains.Chains, model, hcat(varinfos), include_colon_eq
99+
)
100+
end
101+
39102
"""
40103
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41104

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ export AbstractVarInfo,
126126
prefix,
127127
returned,
128128
to_submodel,
129+
varinfos_to_chains,
129130
# Convenience macros
130131
@addlogprob!,
131132
value_iterator_from_chain,

src/abstract_varinfo.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,3 +1182,24 @@ function linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::Var
11821182
f_to_internal = to_internal_transform(varinfo, vn)
11831183
return f_to_internal f_from_linked_internal
11841184
end
1185+
1186+
"""
1187+
varinfos_to_chains(
1188+
Tout::Type{<:AbstractChains},
1189+
model::DynamicPPL.Model,
1190+
varinfos::AbstractArray{<:AbstractVarInfo},
1191+
include_colon_eq::Bool=true
1192+
)
1193+
1194+
Convert an array of `varinfos` to a chains object of type `Tout`.
1195+
1196+
The `model` is required in order to account for cases where the varinfo is linked and
1197+
re-evaluation is required. For example, this is the case when the support of a distribution
1198+
depends on other random variables.
1199+
1200+
`include_colon_eq` indicates whether to include variables on the left-hand side of `:=`.
1201+
1202+
This function is not implemented here but rather in package extensions for individual chains
1203+
packages.
1204+
"""
1205+
function varinfos_to_chains end

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using LinearAlgebra: I
2+
13
@testset "DynamicPPLMCMCChainsExt" begin
24
@model demo() = x ~ Normal()
35
model = demo()
@@ -11,6 +13,40 @@
1113
chain_generated = @test_nowarn returned(model, chain)
1214
@test size(chain_generated) == (1000, 1)
1315
@test mean(chain_generated) 0 atol = 0.1
16+
17+
@testset "varinfos_to_chains" begin
18+
@model function f()
19+
x ~ Normal()
20+
y ~ Normal(x)
21+
return z ~ MvNormal(zeros(3), I)
22+
end
23+
24+
model = f()
25+
26+
@testset "vector" begin
27+
vis = [VarInfo(model) for _ in 1:50]
28+
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, vis)
29+
@test c isa MCMCChains.Chains
30+
@test size(c) == (50, 5, 1)
31+
@test c.info.varname_to_symbol[@varname(x)] == :x
32+
@test c.info.varname_to_symbol[@varname(y)] == :y
33+
@test c.info.varname_to_symbol[@varname(z[1])] == Symbol("z[1]")
34+
@test c.info.varname_to_symbol[@varname(z[2])] == Symbol("z[2]")
35+
@test c.info.varname_to_symbol[@varname(z[3])] == Symbol("z[3]")
36+
end
37+
38+
@testset "matrix" begin
39+
vis = [VarInfo(model) for _ in 1:50, _ in 1:3]
40+
c = DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, vis)
41+
@test c isa MCMCChains.Chains
42+
@test size(c) == (50, 5, 3)
43+
@test c.info.varname_to_symbol[@varname(x)] == :x
44+
@test c.info.varname_to_symbol[@varname(y)] == :y
45+
@test c.info.varname_to_symbol[@varname(z[1])] == Symbol("z[1]")
46+
@test c.info.varname_to_symbol[@varname(z[2])] == Symbol("z[2]")
47+
@test c.info.varname_to_symbol[@varname(z[3])] == Symbol("z[3]")
48+
end
49+
end
1450
end
1551

1652
# test for `predict` is in `test/model.jl`

test/test_util.jl

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -62,35 +62,8 @@ Construct an MCMCChains.Chains object by sampling from the prior of `model` for
6262
`n_iters` iterations.
6363
"""
6464
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
65-
# Sample from the prior
6665
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
67-
# Extract all varnames found in any dictionary. Doing it this way guards
68-
# against the possibility of having different varnames in different
69-
# dictionaries, e.g. for models that have dynamic variables / array sizes
70-
varnames = OrderedSet{VarName}()
71-
# Convert each varinfo into an OrderedDict of vns => params.
72-
# We have to use varname_and_value_leaves so that each parameter is a scalar
73-
dicts = map(varinfos) do t
74-
vals = DynamicPPL.values_as(t, OrderedDict)
75-
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
76-
tuples = mapreduce(collect, vcat, iters)
77-
# The following loop is a replacement for:
78-
# push!(varnames, map(first, tuples)...)
79-
# which causes a stack overflow if `map(first, tuples)` is too large.
80-
# Unfortunately there isn't a union() function for OrderedSet.
81-
for vn in map(first, tuples)
82-
push!(varnames, vn)
83-
end
84-
OrderedDict(tuples)
85-
end
86-
# Convert back to list
87-
varnames = collect(varnames)
88-
# Construct matrix of values
89-
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
90-
# Construct dict of varnames -> symbol
91-
vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
92-
# Construct and return the Chains object
93-
return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
66+
return DynamicPPL.varinfos_to_chains(MCMCChains.Chains, model, varinfos)
9467
end
9568
function make_chain_from_prior(model::Model, n_iters::Int)
9669
return make_chain_from_prior(Random.default_rng(), model, n_iters)

0 commit comments

Comments
 (0)