From 57b338bdf5c93a4ac830dd71f451c52c91aaceff Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 27 Oct 2025 11:48:35 +0000 Subject: [PATCH] Revert "Add `to_chains` and `from_chains` function (#1087)" This reverts commit 11b7e01024b94ce7177ea2f9f2f60fbfafcecced. --- HISTORY.md | 4 - Project.toml | 2 +- docs/src/api.md | 18 --- ext/DynamicPPLMCMCChainsExt.jl | 165 ++++++++++------------------ src/DynamicPPL.jl | 4 - src/to_chains.jl | 158 -------------------------- test/ext/DynamicPPLMCMCChainsExt.jl | 89 --------------- test/runtests.jl | 1 - test/test_util.jl | 35 +++++- test/to_chains.jl | 69 ------------ 10 files changed, 86 insertions(+), 459 deletions(-) delete mode 100644 src/to_chains.jl delete mode 100644 test/to_chains.jl diff --git a/HISTORY.md b/HISTORY.md index 2985e8d41..604dcb725 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,9 +1,5 @@ # DynamicPPL Changelog -## 0.38.3 - -Added a new exported struct, `DynamicPPL.ParamsWithStats`, and a corresponding function `DynamicPPL.to_chains`, which automatically converts a collection of `ParamsWithStats` to a given Chains type. - ## 0.38.2 Added a compatibility entry for JET@0.11. diff --git a/Project.toml b/Project.toml index d54f9d1da..83e0fea3f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.3" +version = "0.38.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index 98f22bf30..80970c0bb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -505,21 +505,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### Converting VarInfos to chains - -It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis. -This can be accomplished with the following: - -```@docs -DynamicPPL.ParamsWithStats -DynamicPPL.to_chains -``` - -Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with: - -```@docs -DynamicPPL.from_chains -``` - -This is useful if you want to use the result of a chain in further model evaluations. diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 528ad7f8e..003372449 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -36,113 +36,6 @@ function chain_sample_to_varname_dict( return d end -""" - DynamicPPL.to_chains( - ::Type{MCMCChains.Chains}, - params_and_stats::AbstractArray{<:ParamsWithStats} - ) - -Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object. -""" -function DynamicPPL.to_chains( - ::Type{MCMCChains.Chains}, - params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats}, -) - # Handle parameters - all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() - split_dicts = map(params_and_stats) do ps - # Separate into individual VarNames. - vn_leaves_and_vals = if isempty(ps.params) - Tuple{DynamicPPL.VarName,Any}[] - else - iters = map( - AbstractPPL.varname_and_value_leaves, - keys(ps.params), - values(ps.params), - ) - mapreduce(collect, vcat, iters) - end - vn_leaves = map(first, vn_leaves_and_vals) - vals = map(last, vn_leaves_and_vals) - for vn_leaf in vn_leaves - push!(all_vn_leaves, vn_leaf) - end - DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals)) - end - vn_leaves = collect(all_vn_leaves) - param_vals = [ - get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)), - key in vn_leaves, j in eachindex(axes(split_dicts, 2)) - ] - param_symbols = map(Symbol, vn_leaves) - # Handle statistics - stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}() - for ps in params_and_stats - for k in keys(ps.stats) - push!(stat_keys, k) - end - end - stat_keys = collect(stat_keys) - stat_vals = [ - get(params_and_stats[i, j].stats, key, missing) for - i in eachindex(axes(params_and_stats, 1)), key in stat_keys, - j in eachindex(axes(params_and_stats, 2)) - ] - # Construct name map and info - name_map = (internals=stat_keys,) - info = ( - varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict( - zip(all_vn_leaves, param_symbols) - ), - ) - # Concatenate parameter and statistic values - vals = cat(param_vals, stat_vals; dims=2) - symbols = vcat(param_symbols, stat_keys) - return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info) -end -function DynamicPPL.to_chains( - ::Type{MCMCChains.Chains}, ps::AbstractVector{<:DynamicPPL.ParamsWithStats} -) - return DynamicPPL.to_chains(MCMCChains.Chains, hcat(ps)) -end - -function DynamicPPL.from_chains( - ::Type{T}, chain::MCMCChains.Chains -) where {T<:AbstractDict{<:DynamicPPL.VarName}} - idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - matrix = map(idxs) do (sample_idx, chain_idx) - d = T() - for vn in DynamicPPL.varnames(chain) - d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx) - end - d - end - return matrix -end -function DynamicPPL.from_chains(::Type{NamedTuple}, chain::MCMCChains.Chains) - idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - matrix = map(idxs) do (sample_idx, chain_idx) - get(chain[sample_idx, :, chain_idx], keys(chain); flatten=true) - end - return matrix -end -function DynamicPPL.from_chains( - ::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains -) - idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - internals_chain = MCMCChains.get_sections(chain, :internals) - params = DynamicPPL.from_chains( - DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,eltype(chain.value)}, - chain, - ) - stats = DynamicPPL.from_chains(NamedTuple, internals_chain) - return map(idxs) do (sample_idx, chain_idx) - DynamicPPL.ParamsWithStats( - params[sample_idx, chain_idx], stats[sample_idx, chain_idx] - ) - end -end - """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -217,6 +110,7 @@ function DynamicPPL.predict( DynamicPPL.VarInfo(), ( DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(), DynamicPPL.ValuesAsInModelAccumulator(false), ), @@ -235,9 +129,23 @@ function DynamicPPL.predict( varinfo, DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), ) - DynamicPPL.ParamsWithStats(varinfo, nothing) + vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values + varname_vals = mapreduce( + collect, + vcat, + map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)), + ) + + return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) end - chain_result = DynamicPPL.to_chains(MCMCChains.Chains, predictive_samples) + + chain_result = reduce( + MCMCChains.chainscat, + [ + _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for + chain_idx in 1:size(predictive_samples, 2) + ], + ) parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -256,6 +164,45 @@ function DynamicPPL.predict( ) end +function _predictive_samples_to_arrays(predictive_samples) + variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() + + sample_dicts = map(predictive_samples) do sample + varname_value_pairs = sample.varname_and_values + varnames = map(first, varname_value_pairs) + values = map(last, varname_value_pairs) + for varname in varnames + push!(variable_names_set, varname) + end + + return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values)) + end + + variable_names = collect(variable_names_set) + variable_values = [ + get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts), + key in variable_names + ] + + return variable_names, variable_values +end + +function _predictive_samples_to_chains(predictive_samples) + variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples) + variable_names_symbols = map(Symbol, variable_names) + + internal_parameters = [:lp] + log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1) + + parameter_names = [variable_names_symbols; internal_parameters] + parameter_values = hcat(variable_values, log_probabilities) + parameter_values = MCMCChains.concretize(parameter_values) + + return MCMCChains.Chains( + parameter_values, parameter_names, (internals=internal_parameters,) + ) +end + """ returned(model::Model, chain::MCMCChains.Chains) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 25d6f7282..f5bd33d6d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -126,9 +126,6 @@ export AbstractVarInfo, prefix, returned, to_submodel, - # Chain construction - ParamsWithStats, - to_chains, # Convenience macros @addlogprob!, value_iterator_from_chain, @@ -197,7 +194,6 @@ include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") include("bijector.jl") -include("to_chains.jl") include("debug_utils.jl") using .DebugUtils diff --git a/src/to_chains.jl b/src/to_chains.jl deleted file mode 100644 index 365cae187..000000000 --- a/src/to_chains.jl +++ /dev/null @@ -1,158 +0,0 @@ -""" - ParamsWithStats - -A struct which contains parameter values extracted from a `VarInfo`, along with any -statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are -optional. - - ParamsWithStats( - varinfo::AbstractVarInfo, - model::Model, - stats::NamedTuple=NamedTuple(); - include_colon_eq::Bool=true, - include_log_probs::Bool=true, - ) - -Generate a `ParamsWithStats` by re-evaluating the given `model` with the provided `varinfo`. -Re-evaluation of the model is often necessary to obtain correct parameter values as well as -log probabilities. This is especially true when using linked VarInfos, i.e., when variables -have been transformed to unconstrained space, and if this is not done, subtle correctness -bugs may arise: see, e.g., https://github.com/TuringLang/Turing.jl/issues/2195. - -`include_colon_eq` controls whether variables on the left-hand side of `:=` are included in -the resulting parameters. - -`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log -joint) are added to the resulting statistics NamedTuple. - - ParamsWithStats( - varinfo::AbstractVarInfo, - ::Nothing, - stats::NamedTuple=NamedTuple(); - include_log_probs::Bool=true, - ) - -There is one case where re-evaluation is not necessary, which is when the VarInfos all -already contain `DynamicPPL.ValuesAsInModelAccumulator`. This accumulator stores values -as seen during the model evaluation, so the values can be simply read off. In this case, -`model` can be set to `nothing`, and no re-evaluation will be performed. However, it is the -caller's responsibility to ensure that `ValuesAsInModelAccumulator` is indeed -present. - -`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log -joint) are added to the resulting statistics NamedTuple. -""" -struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple} - params::P - stats::S - - function ParamsWithStats( - params::P, stats::S - ) where {P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple} - return new{P,S}(params, stats) - end - - function ParamsWithStats( - varinfo::AbstractVarInfo, - model::DynamicPPL.Model, - stats::NamedTuple=NamedTuple(); - include_colon_eq::Bool=true, - include_log_probs::Bool=true, - ) - varinfo = maybe_to_typed_varinfo(varinfo) - accs = if include_log_probs - ( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), - ) - else - (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) - end - varinfo = DynamicPPL.setaccs!!(varinfo, accs) - varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) - params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values - if include_log_probs - stats = merge( - stats, - ( - logprior=DynamicPPL.getlogprior(varinfo), - loglikelihood=DynamicPPL.getloglikelihood(varinfo), - lp=DynamicPPL.getlogjoint(varinfo), - ), - ) - end - return new{typeof(params),typeof(stats)}(params, stats) - end - - function ParamsWithStats( - varinfo::AbstractVarInfo, - ::Nothing, - stats::NamedTuple=NamedTuple(); - include_log_probs::Bool=true, - ) - params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values - if include_log_probs - has_prior_acc = DynamicPPL.hasacc(varinfo, Val(:LogPrior)) - has_likelihood_acc = DynamicPPL.hasacc(varinfo, Val(:LogLikelihood)) - if has_prior_acc - stats = merge(stats, (logprior=DynamicPPL.getlogprior(varinfo),)) - end - if has_likelihood_acc - stats = merge(stats, (loglikelihood=DynamicPPL.getloglikelihood(varinfo),)) - end - if has_prior_acc && has_likelihood_acc - stats = merge(stats, (logjoint=DynamicPPL.getlogjoint(varinfo),)) - end - end - return new{typeof(params),typeof(stats)}(params, stats) - end -end - -# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to -# convert it to a typed varinfo first, hence this method. -# https://github.com/TuringLang/Turing.jl/issues/2604 -maybe_to_typed_varinfo(vi::VarInfo{<:Metadata}) = typed_varinfo(vi) -maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi - -""" - to_chains( - Tout::Type{<:AbstractChains}, - params_and_stats::AbstractArray{<:ParamsWithStats} - )::Tout - -Convert an array of `ParamsWithStats` to a chains object of type `Tout`. - -This function is not implemented here but rather in package extensions for individual chains -packages. -""" -function to_chains end - -""" - from_chains( - ::Type{Tout}, - chain::AbstractChains - )::AbstractMatrix{<:Tout} - -Convert a chains object to an array of size (niters * nchains) with element type `Tout`. - -Note that even if `chain` contains only a single chain, this is still returned as a matrix, -not a vector. - -This function is not implemented here but rather in package extensions for individual chains -packages. - -Common implementations include: - - - `Tout = ParamsWithStats`: obtain both parameters and statistics - - `Tout <: AbstractDict{<:VarName}`: obtain the parameters only (since stats are not stored - as `VarName`s - - `Tout = NamedTuple`: obtain both parameters and statistics as a NamedTuple - -!!! warning - Note that `Tout = NamedTuple` potentially causes a loss of information especially when - used with `MCMCChains.Chains`, since variable names are not preserved. This may lead to - bugs if the NamedTuple is later used for other purposes, such as evaluating a model. To - avoid this, you should always use something like `Tout = OrderedDict{VarName,Any}`. -""" -function from_chains end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index bb8895a0e..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -11,95 +11,6 @@ chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 - - @testset "to_chains" begin - @model function f(z) - x ~ Normal() - y := x + 1 - return z ~ Normal(y) - end - - z = 1.0 - model = f(z) - - @testset "vector" begin - ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50] - c = DynamicPPL.to_chains(MCMCChains.Chains, ps) - @test c isa MCMCChains.Chains - @test size(c, 1) == 50 - @test size(c, 3) == 1 - @test Set(c.name_map.parameters) == Set([:x, :y]) - @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp]) - @test logpdf.(Normal(), c[:x]) ≈ c[:logprior] - @test c.info.varname_to_symbol[@varname(x)] == :x - @test c.info.varname_to_symbol[@varname(y)] == :y - end - - @testset "matrix" begin - ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50, _ in 1:3] - c = DynamicPPL.to_chains(MCMCChains.Chains, ps) - @test c isa MCMCChains.Chains - @test size(c, 1) == 50 - @test size(c, 3) == 3 - @test Set(c.name_map.parameters) == Set([:x, :y]) - @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp]) - @test logpdf.(Normal(), c[:x]) ≈ c[:logprior] - @test c.info.varname_to_symbol[@varname(x)] == :x - @test c.info.varname_to_symbol[@varname(y)] == :y - end - end - - @testset "from_chains" begin - @model function f(z) - x ~ Normal() - y := x + 1 - return z ~ Normal(y) - end - - z = 1.0 - model = f(z) - ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50] - c = DynamicPPL.to_chains(MCMCChains.Chains, ps) - - @testset "OrderedDict" begin - arr_dicts = DynamicPPL.from_chains(OrderedDict{VarName,Any}, c) - @test size(arr_dicts) == (50, 1) - for i in 1:50 - dict = arr_dicts[i, 1] - @test dict isa OrderedDict{VarName,Any} - p = ps[i].params - @test dict[@varname(x)] == p[@varname(x)] - @test dict[@varname(y)] == p[@varname(y)] - @test length(dict) == 2 - end - end - - @testset "NamedTuple" begin - arr_nts = DynamicPPL.from_chains(NamedTuple, c) - @test size(arr_nts) == (50, 1) - for i in 1:50 - nt = arr_nts[i, 1] - @test length(nt) == 5 - p = ps[i] - @test nt.x == p.params[@varname(x)] - @test nt.y == p.params[@varname(y)] - @test nt.lp == p.stats.lp - @test nt.logprior == p.stats.logprior - @test nt.loglikelihood == p.stats.loglikelihood - end - end - - @testset "ParamsWithStats" begin - arr_pss = DynamicPPL.from_chains(ParamsWithStats, c) - @test size(arr_pss) == (50, 1) - for i in 1:50 - new_p = arr_pss[i, 1] - p = ps[i] - @test new_p.params == p.params - @test new_p.stats == p.stats - end - end - end end # test for `predict` is in `test/model.jl` diff --git a/test/runtests.jl b/test/runtests.jl index fb1d92d7c..b6a3f7bf6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,7 +73,6 @@ include("test_util.jl") include("debug_utils.jl") include("submodels.jl") include("bijector.jl") - include("to_chains.jl") end if GROUP == "All" || GROUP == "Group2" diff --git a/test/test_util.jl b/test/test_util.jl index a37d10f31..164751c7b 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -62,12 +62,35 @@ Construct an MCMCChains.Chains object by sampling from the prior of `model` for `n_iters` iterations. """ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) - vi = VarInfo(model) - vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.ValuesAsInModelAccumulator(false),)) - ps = [ - ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi)), nothing) for _ in 1:n_iters - ] - return DynamicPPL.to_chains(MCMCChains.Chains, ps) + # Sample from the prior + varinfos = [VarInfo(rng, model) for _ in 1:n_iters] + # Extract all varnames found in any dictionary. Doing it this way guards + # against the possibility of having different varnames in different + # dictionaries, e.g. for models that have dynamic variables / array sizes + varnames = OrderedSet{VarName}() + # Convert each varinfo into an OrderedDict of vns => params. + # We have to use varname_and_value_leaves so that each parameter is a scalar + dicts = map(varinfos) do t + vals = DynamicPPL.values_as(t, OrderedDict) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) + tuples = mapreduce(collect, vcat, iters) + # The following loop is a replacement for: + # push!(varnames, map(first, tuples)...) + # which causes a stack overflow if `map(first, tuples)` is too large. + # Unfortunately there isn't a union() function for OrderedSet. + for vn in map(first, tuples) + push!(varnames, vn) + end + OrderedDict(tuples) + end + # Convert back to list + varnames = collect(varnames) + # Construct matrix of values + vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) + # Construct and return the Chains object + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/to_chains.jl b/test/to_chains.jl deleted file mode 100644 index 9d6c04b5f..000000000 --- a/test/to_chains.jl +++ /dev/null @@ -1,69 +0,0 @@ -module DynamicPPLToChainsTests - -using DynamicPPL -using Distributions -using Test - -@testset "ParamsWithStats" begin - @model function f(z) - x ~ Normal() - y := x + 1 - return z ~ Normal(y) - end - z = 1.0 - model = f(z) - - @testset "with reevaluation" begin - ps = ParamsWithStats(VarInfo(model), model) - @test haskey(ps.params, @varname(x)) - @test haskey(ps.params, @varname(y)) - @test length(ps.params) == 2 - @test haskey(ps.stats, :logprior) - @test haskey(ps.stats, :loglikelihood) - @test haskey(ps.stats, :lp) - @test length(ps.stats) == 3 - @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood - @test ps.params[@varname(y)] ≈ ps.params[@varname(x)] + 1 - @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) - @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(y)]), z) - end - - @testset "without colon_eq" begin - ps = ParamsWithStats(VarInfo(model), model; include_colon_eq=false) - @test haskey(ps.params, @varname(x)) - @test length(ps.params) == 1 - @test haskey(ps.stats, :logprior) - @test haskey(ps.stats, :loglikelihood) - @test haskey(ps.stats, :lp) - @test length(ps.stats) == 3 - @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood - @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) - @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(x)] + 1), z) - end - - @testset "without log probs" begin - ps = ParamsWithStats(VarInfo(model), model; include_log_probs=false) - @test haskey(ps.params, @varname(x)) - @test haskey(ps.params, @varname(y)) - @test length(ps.params) == 2 - @test isempty(ps.stats) - end - - @testset "no reevaluation" begin - # Without VAIM, it should error - @test_throws ErrorException ParamsWithStats(VarInfo(model), nothing) - # With VAIM, it should work - vi = DynamicPPL.setaccs!!( - VarInfo(model), (DynamicPPL.ValuesAsInModelAccumulator(true),) - ) - vi = last(DynamicPPL.evaluate!!(model, vi)) - ps = ParamsWithStats(vi, nothing) - @test haskey(ps.params, @varname(x)) - @test haskey(ps.params, @varname(y)) - @test length(ps.params) == 2 - # Because we didn't evaluate with log prob accumulators, there should be no stats - @test isempty(ps.stats) - end -end - -end # module