Skip to content

Commit 6c615ad

Browse files
authored
disable fallback for returned and pointwise_logdensities (#1159)
* disable fallback for returned and pointwise_logdensities * Add tests
1 parent 098e7b0 commit 6c615ad

File tree

5 files changed

+76
-23
lines changed

5 files changed

+76
-23
lines changed

HISTORY.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.10
4+
5+
`returned(model, chain)` and `pointwise_logdensities(model, chain)` will now error if a value for a random variable cannot be found in the chain.
6+
(Previously, they would instead resample such variables, which could lead to silent mistakes.)
7+
8+
If you encounter this error and it is accompanied by a warning about `hasvalue` not being implemented, you should be able to fix this by [using FlexiChains instead of MCMCChains](https://github.com/penelopeysm/FlexiChains.jl).
9+
(Alternatively, implementations of `hasvalue` for unsupported distributions are more than welcome; these must be provided in the Distributions extension of AbstractPPL.jl.)
10+
311
## 0.38.9
412

513
Remove warning when using Enzyme as the AD backend.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.9"
3+
version = "0.38.10"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,10 @@ function DynamicPPL.predict(
226226
)
227227
predictions = map(params_and_stats) do ps
228228
_, varinfo = DynamicPPL.init!!(
229-
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
229+
rng,
230+
model,
231+
varinfo,
232+
DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
230233
)
231234
DynamicPPL.ParamsWithStats(varinfo)
232235
end
@@ -316,11 +319,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
316319
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
317320
return map(params_with_stats) do ps
318321
first(
319-
DynamicPPL.init!!(
320-
model,
321-
varinfo,
322-
DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
323-
),
322+
DynamicPPL.init!!(model, varinfo, DynamicPPL.InitFromParams(ps.params, nothing))
324323
)
325324
end
326325
end
@@ -426,9 +425,7 @@ function DynamicPPL.pointwise_logdensities(
426425
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
427426
# Re-evaluate the model
428427
_, vi = DynamicPPL.init!!(
429-
model,
430-
vi,
431-
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
428+
model, vi, DynamicPPL.InitFromParams(values_dict, nothing)
432429
)
433430
DynamicPPL.getacc(vi, Val(accname)).logps
434431
end

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,6 @@ module DynamicPPLMCMCChainsExtTests
33
using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC
44

55
@testset "DynamicPPLMCMCChainsExt" begin
6-
@model demo() = x ~ Normal()
7-
model = demo()
8-
9-
chain = MCMCChains.Chains(
10-
randn(1000, 2, 1),
11-
[:x, :y],
12-
Dict(:internals => [:y]);
13-
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
14-
)
15-
chain_generated = @test_nowarn returned(model, chain)
16-
@test size(chain_generated) == (1000, 1)
17-
@test mean(chain_generated) 0 atol = 0.1
18-
196
@testset "from_samples" begin
207
@model function f(z)
218
x ~ Normal()
@@ -61,6 +48,42 @@ using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC
6148
@test new_p.stats == p.stats
6249
end
6350
end
51+
52+
@testset "returned (basic)" begin
53+
@model demo() = x ~ Normal()
54+
model = demo()
55+
56+
chain = MCMCChains.Chains(
57+
randn(1000, 2, 1),
58+
[:x, :y],
59+
Dict(:internals => [:y]);
60+
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
61+
)
62+
chain_generated = @test_nowarn returned(model, chain)
63+
@test size(chain_generated) == (1000, 1)
64+
@test mean(chain_generated) 0 atol = 0.1
65+
end
66+
67+
@testset "returned: errors on missing variable" begin
68+
# Create a chain that only has `m`.
69+
@model function m_only()
70+
return m ~ Normal()
71+
end
72+
model_m_only = m_only()
73+
chain_m_only = AbstractMCMC.from_samples(
74+
MCMCChains.Chains,
75+
hcat([ParamsWithStats(VarInfo(model_m_only), model_m_only) for _ in 1:50]),
76+
)
77+
78+
# Define a model that needs both `m` and `s`.
79+
@model function f()
80+
m ~ Normal()
81+
s ~ Exponential()
82+
return y ~ Normal(m, s)
83+
end
84+
model = f() | (; y=1.0)
85+
@test_throws "No value was provided" returned(model, chain_m_only)
86+
end
6487
end
6588

6689
# test for `predict` is in `test/model.jl`

test/pointwise_logdensities.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,29 @@ end
9494
@test logprior logprior_true
9595
@test loglikelihood loglikelihood_true
9696
end
97+
98+
@testset "errors when variables are missing" begin
99+
# Create a chain that only has `m`.
100+
@model function m_only()
101+
return m ~ Normal()
102+
end
103+
model_m_only = m_only()
104+
chain_m_only = AbstractMCMC.from_samples(
105+
MCMCChains.Chains,
106+
hcat([ParamsWithStats(VarInfo(model_m_only), model_m_only) for _ in 1:50]),
107+
)
108+
109+
# Define a model that needs both `m` and `s`.
110+
@model function f()
111+
m ~ Normal()
112+
s ~ Exponential()
113+
return y ~ Normal(m, s)
114+
end
115+
model = f() | (; y=1.0)
116+
@test_throws "No value was provided" pointwise_logdensities(model, chain_m_only)
117+
@test_throws "No value was provided" pointwise_loglikelihoods(model, chain_m_only)
118+
@test_throws "No value was provided" pointwise_prior_logdensities(
119+
model, chain_m_only
120+
)
121+
end
97122
end

0 commit comments

Comments
 (0)