From 8135113c924c4fe53197c02c7e780fe9846580c4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Apr 2025 17:47:27 +0100 Subject: [PATCH 01/27] Bump minor version to 0.37.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 67d844d99..01e2cb612 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.0" +version = "0.37.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 299e17b915bfc6cdf9f3f0d467423f0b7258243e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 2 May 2025 14:12:34 +0100 Subject: [PATCH 02/27] Accumulators, stage 1 (#885) * Release 0.36 * AbstractPPL 0.11 + change prefixing behaviour (#830) * AbstractPPL 0.11; change prefixing behaviour * Use DynamicPPL.prefix rather than overloading * Remove VarInfo(VarInfo, params) (#870) * Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879) * Unify {Untyped,Typed}{Vector,}VarInfo constructors * Update invocations * NTVarInfo * Fix tests * More fixes * Fixes * Fixes * Fixes * Use lowercase functions, don't deprecate VarInfo * Rewrite VarInfo docstring * Fix methods * Fix methods (really) * Draft of accumulators * Fix some variable names * Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!! * Map rather than broadcast Co-authored-by: Tor Erlend Fjelde * Start documenting accumulators * Use Val{symbols} instead of AccTypes to index * More documentation for accumulators * Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890) * Link VarInfo by default * Tweak interface * Fix tests * Fix interface so that callers can inspect results * Document * Fix tests * Fix changelog * Test linked varinfos Closes #891 * Fix docstring + use AbstractFloat * Fix resetlogp!! 
and type stability for accumulators * Fix type rigidity of LogProbs and NumProduce * Fix uses of getlogp and other assorted issues * setaccs!! nicer interface and logdensity function fixes * Revert back to calling the macro @addlogprob! * Remove a dead test * Clarify a comment * Implement split/combine for PointwiseLogdensityAccumulator * Switch ThreadSafeVarInfo.accs_by_thread to be a tuple * Fix `condition` and `fix` in submodels (#892) * Fix conditioning in submodels * Simplify contextual_isassumption * Add documentation * Fix some tests * Add tests; fix a bunch of nested submodel issues * Fix fix as well * Fix doctests * Add unit tests for new functions * Add changelog entry * Update changelog Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Finish docs * Add a test for conditioning submodel via arguments * Clean new tests up a bit * Fix for VarNames with non-identity lenses * Apply suggestions from code review Co-authored-by: Markus Hauru * Apply suggestions from code review * Make PrefixContext contain a varname rather than symbol (#896) --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Markus Hauru * Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting in (Simple)VarInfo * Improve accumulator docs * Add test/accumulators.jl * Docs fixes * Various small fixes * Make DynamicTransformation not use accumulators other than LogPrior * Fix variable order and name of map_accumulator!! * Typo fixing * Small improvement to ThreadSafeVarInfo * Fix demo_dot_assume_observe_submodel prefixing * Typo fixing * Miscellaneous small fixes * HISTORY entry and more miscellanea * Add more tests for accumulators * Improve accumulators docstrings * Fix a typo * Expand HISTORY entry * Add accumulators to API docs * Remove unexported functions from API docs * Add NamedTuple methods for get/set/acclogp * Fix setlogp!! 
with single scalar to error * Export AbstractAccumulator, fix a docs typo * Apply suggestions from code review Co-authored-by: Penelope Yong * Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce * Type bound log prob accumulators with T<:Real * Add @addlogprior! and @addloglikelihood! * Apply suggestions from code review Co-authored-by: Penelope Yong * Move default accumulators to default_accumulators.jl * Fix some tests * Introduce default_accumulators() * Go back to only having @addlogprob! * Fix tilde_observe!! prefixing * Fix default_accumulators internal type * Make unflatten more type stable, and add a test for it * Always print all benchmark results * Move NumProduce VI functions to abstract_varinfo.jl --------- Co-authored-by: Penelope Yong Co-authored-by: Tor Erlend Fjelde Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- HISTORY.md | 18 ++ Project.toml | 2 + benchmarks/benchmarks.jl | 1 + docs/src/api.md | 38 ++-- ext/DynamicPPLMCMCChainsExt.jl | 12 +- src/DynamicPPL.jl | 28 ++- src/abstract_varinfo.jl | 321 +++++++++++++++++++++++++++++--- src/accumulators.jl | 189 +++++++++++++++++++ src/compiler.jl | 2 +- src/context_implementations.jl | 152 ++++----------- src/contexts.jl | 61 ++---- src/debug_utils.jl | 56 +++--- src/default_accumulators.jl | 154 +++++++++++++++ src/logdensityfunction.jl | 24 ++- src/model.jl | 26 ++- src/pointwise_logdensities.jl | 263 ++++++++++++-------------- src/simple_varinfo.jl | 122 +++++------- src/submodel_macro.jl | 4 +- src/test_utils/contexts.jl | 30 +-- src/test_utils/models.jl | 30 +-- src/test_utils/varinfo.jl | 14 +- src/threadsafe.jl | 117 +++++++----- src/transforming.jl | 41 +++- src/utils.jl | 61 +++--- src/values_as_in_model.jl | 13 +- src/varinfo.jl | 318 +++++++++++++++++-------------- test/accumulators.jl | 176 +++++++++++++++++ test/compiler.jl | 18 +- test/context_implementations.jl | 7 +- test/contexts.jl | 14 +- test/independence.jl | 11 -- test/linking.jl | 12 +- 
test/model.jl | 4 +- test/pointwise_logdensities.jl | 7 - test/runtests.jl | 2 +- test/sampler.jl | 8 +- test/simple_varinfo.jl | 34 ++-- test/submodels.jl | 10 +- test/threadsafe.jl | 49 ++--- test/utils.jl | 29 ++- test/varinfo.jl | 201 +++++++++++++++++--- test/varnamedvector.jl | 4 +- 42 files changed, 1768 insertions(+), 915 deletions(-) create mode 100644 src/accumulators.jl create mode 100644 src/default_accumulators.jl create mode 100644 test/accumulators.jl delete mode 100644 test/independence.jl diff --git a/HISTORY.md b/HISTORY.md index 9a70e8d1f..68650f9d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,23 @@ # DynamicPPL Changelog +## 0.37.0 + +**Breaking changes** + +### Accumulators + +This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: + + - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`. + - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. + - `tilde_observe` and `observe` have been removed. 
`tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. + - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. + - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`. + - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. + - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. + - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. + - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well. 
+ ## 0.36.0 **Breaking changes** diff --git a/Project.toml b/Project.toml index 01e2cb612..25c6acd24 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -68,6 +69,7 @@ MCMCChains = "6" MacroTools = "0.5.6" Mooncake = "0.4.95" OrderedCollections = "1" +Printf = "1.10" Random = "1.6" Requires = "1" Statistics = "1" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 89b65d2de..9661dd505 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -100,4 +100,5 @@ PrettyTables.pretty_table( header=header, tf=PrettyTables.tf_markdown, formatters=ft_printf("%.1f", [6, 7]), + crop=:none, # Always print the whole table, even if it doesn't fit in the terminal. ) diff --git a/docs/src/api.md b/docs/src/api.md index 08522e2ce..8e5c64886 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,7 +160,7 @@ returned(::Model) ## Utilities -It is possible to manually increase (or decrease) the accumulated log density from within a model function. +It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @addlogprob! @@ -328,9 +328,9 @@ The following functions were used for sequential Monte Carlo methods. ```@docs get_num_produce -set_num_produce! -increment_num_produce! -reset_num_produce! +set_num_produce!! +increment_num_produce!! +reset_num_produce!! setorder! set_retained_vns_del! ``` @@ -345,6 +345,22 @@ Base.empty! 
SimpleVarInfo ``` +### Accumulators + +The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. + +```@docs +AbstractAccumulator +``` + +DynamicPPL provides the following default accumulators. + +```@docs +LogPriorAccumulator +LogLikelihoodAccumulator +NumProduceAccumulator +``` + ### Common API #### Accumulation of log-probabilities @@ -353,6 +369,13 @@ SimpleVarInfo getlogp setlogp!! acclogp!! +getlogjoint +getlogprior +setlogprior!! +acclogprior!! +getloglikelihood +setloglikelihood!! +accloglikelihood!! resetlogp!! ``` @@ -427,9 +450,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs SamplingContext DefaultContext -LikelihoodContext -PriorContext -MiniBatchContext PrefixContext ConditionContext ``` @@ -476,7 +496,3 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume ``` - -```@docs -tilde_observe -``` diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7fcbd6a7c..70f0f0182 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -48,10 +48,10 @@ end Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample in `chain`, and return the resulting `Chains`. -The `model` passed to `predict` is often different from the one used to generate `chain`. -Typically, the model from which `chain` originated treats certain variables as observed (i.e., -data points), while the model you pass to `predict` may mark these same variables as missing -or unobserved. Calling `predict` then leverages the previously inferred parameter values to +The `model` passed to `predict` is often different from the one used to generate `chain`. 
+Typically, the model from which `chain` originated treats certain variables as observed (i.e., +data points), while the model you pass to `predict` may mark these same variables as missing +or unobserved. Calling `predict` then leverages the previously inferred parameter values to simulate what new, unobserved data might look like, given your posterior beliefs. For each parameter configuration in `chain`: @@ -59,7 +59,7 @@ For each parameter configuration in `chain`: 2. Any variables not included in `chain` are sampled from their prior distributions. If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by -the samples in `chain`. This is useful when you want to sample only new variables from the posterior +the samples in `chain`. This is useful when you want to sample only new variables from the posterior predictive distribution. # Examples @@ -124,7 +124,7 @@ function DynamicPPL.predict( map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), ) - return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo)) + return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) end chain_result = reduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1c613d08..7527c8be2 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Bijectors using Compat using Distributions using OrderedCollections: OrderedCollections, OrderedDict +using Printf: Printf using AbstractMCMC: AbstractMCMC using ADTypes: ADTypes @@ -46,17 +47,28 @@ import Base: export AbstractVarInfo, VarInfo, SimpleVarInfo, + AbstractAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + NumProduceAccumulator, push!!, empty!!, subset, getlogp, + getlogjoint, + getlogprior, + getloglikelihood, setlogp!!, + setlogprior!!, + setloglikelihood!!, acclogp!!, + acclogprior!!, + accloglikelihood!!, resetlogp!!, get_num_produce, - set_num_produce!, - reset_num_produce!, - 
increment_num_produce!, + set_num_produce!!, + reset_num_produce!!, + increment_num_produce!!, set_retained_vns_del!, is_flagged, set_flag!, @@ -92,15 +104,10 @@ export AbstractVarInfo, # Contexts SamplingContext, DefaultContext, - LikelihoodContext, - PriorContext, - MiniBatchContext, PrefixContext, ConditionContext, assume, - observe, tilde_assume, - tilde_observe, # Pseudo distributions NamedDist, NoDist, @@ -146,6 +153,9 @@ macro prob_str(str) )) end +# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo +# has to implement. Not sure what the full list is for parameters values, but for +# accumulators we only need `getaccs` and `setaccs!!`. """ AbstractVarInfo @@ -166,6 +176,8 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varnamedvector.jl") +include("accumulators.jl") +include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index f11b8a3ec..4917a4892 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -90,45 +90,289 @@ Return the `AbstractTransformation` related to `vi`. function transformation end # Accumulation of log-probabilities. +""" + getlogjoint(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters in `vi`. + +See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). +""" +getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) + """ getlogp(vi::AbstractVarInfo) -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. +Return a NamedTuple of the log prior and log likelihood probabilities. + +The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an +error will be thrown. 
+""" +function getlogp(vi::AbstractVarInfo) + return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) +end + +""" + setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) + setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N}) + +Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. + +`setaccs!!(vi:AbstractVarInfo, accs::AccumulatorTuple) should be implemented by each subtype +of `AbstractVarInfo`. +""" +function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N} + return setaccs!!(vi, AccumulatorTuple(accs)) +end + +""" + getaccs(vi::AbstractVarInfo) + +Return the `AccumulatorTuple` of `vi`. + +This should be implemented by each subtype of `AbstractVarInfo`. +""" +function getaccs end + +""" + hasacc(vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Return a boolean for whether `vi` has an accumulator with name `accname`. +""" +hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) +function hasacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use hasacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + acckeys(vi::AbstractVarInfo) + +Return the names of the accumulators in `vi`. +""" +acckeys(vi::AbstractVarInfo) = keys(getaccs(vi)) + +""" + getlogprior(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters in `vi`. + +See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ref). +""" +getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp + +""" + getloglikelihood(vi::AbstractVarInfo) + +Return the log of the likelihood probability of the observed data in `vi`. + +See also: [`getlogjoint`](@ref), [`getlogprior`](@ref), [`setloglikelihood!!`](@ref). 
+""" +getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp + +""" + setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + +Add `acc` to the `AccumulatorTuple` of `vi`, mutating if it makes sense. + +If an accumulator with the same [`accumulator_name`](@ref) already exists, it will be +replaced. + +See also: [`getaccs`](@ref). +""" +function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + return setaccs!!(vi, setacc!!(getaccs(vi), acc)) +end + +""" + setlogprior!!(vi::AbstractVarInfo, logp) + +Set the log of the prior probability of the parameters sampled in `vi` to `logp`. + +See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref). +""" +setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) + +""" + setloglikelihood!!(vi::AbstractVarInfo, logp) + +Set the log of the likelihood probability of the observed data sampled in `vi` to `logp`. + +See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref). +""" +setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp)) + +""" + setlogp!!(vi::AbstractVarInfo, logp::NamedTuple) + +Set both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and `loglikelihood` and no other fields. + +See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). +""" +function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} + if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) + error("logp must have the fields logprior and loglikelihood and no other fields.") + end + vi = setlogprior!!(vi, logp.logprior) + vi = setloglikelihood!!(vi, logp.loglikelihood) + return vi +end + +function setlogp!!(vi::AbstractVarInfo, logp::Number) + return error(""" + `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. 
Use + `setloglikelihood!!` and/or `setlogprior!!` instead. + """) +end + +""" + getacc(vi::AbstractVarInfo, ::Val{accname}) + +Return the `AbstractAccumulator` of `vi` with name `accname`. +""" +function getacc(vi::AbstractVarInfo, accname::Val) + return getacc(getaccs(vi), accname) +end +function getacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method getacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use getacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + +Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. +""" +function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi) +end + +""" + accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + +Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. """ -function getlogp end +function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi) +end """ - setlogp!!(vi::AbstractVarInfo, logp) + map_accumulators!!(func::Function, vi::AbstractVarInfo) -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. +Update all accumulators of `vi` by calling `func` on them and replacing them with the return +values. """ -function setlogp!! end +function map_accumulators!!(func::Function, vi::AbstractVarInfo) + return setaccs!!(vi, map(func, getaccs(vi))) +end """ - acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname} -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. 
+Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the +return value. """ -function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(NodeTrait(context), context, vi, logp) +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val) + return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname)) +end + +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + does not exist. For type stability reasons use + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead. + """ + ) end -function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(vi, logp) + +""" + acclogprior!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the prior probability in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). +""" +function acclogprior!!(vi::AbstractVarInfo, logp) + return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) end -function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(childcontext(context), vi, logp) + +""" + accloglikelihood!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the likelihood in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). +""" +function accloglikelihood!!(vi::AbstractVarInfo, logp) + return map_accumulator!!( + acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) + ) +end + +""" + acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false) + +Add to both the log prior and the log likelihood probabilities in `vi`. 
+ +`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields. + +By default if the necessary accumulators are not in `vi` an error is thrown. If +`ignore_missing_accumulator` is set to `true` then this is silently ignored instead. +""" +function acclogp!!( + vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false +) where {names} + if !( + names == (:logprior, :loglikelihood) || + names == (:loglikelihood, :logprior) || + names == (:logprior,) || + names == (:loglikelihood,) + ) + error("logp must have fields logprior and/or loglikelihood and no other fields.") + end + if haskey(logp, :logprior) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior))) + vi = acclogprior!!(vi, logp.logprior) + end + if haskey(logp, :loglikelihood) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood))) + vi = accloglikelihood!!(vi, logp.loglikelihood) + end + return vi +end + +function acclogp!!(vi::AbstractVarInfo, logp::Number) + Base.depwarn( + "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.", + :acclogp, + ) + return accloglikelihood!!(vi, logp) end """ resetlogp!!(vi::AbstractVarInfo) -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. +Reset the values of the log probabilities (prior and likelihood) in `vi` to zero. """ -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) +function resetlogp!!(vi::AbstractVarInfo) + if hasacc(vi, Val(:LogPrior)) + vi = map_accumulator!!(zero, vi, Val(:LogPrior)) + end + if hasacc(vi, Val(:LogLikelihood)) + vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) + end + return vi +end # Variables and their realizations. 
@doc """ @@ -566,8 +810,8 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, y), lp_new) + lp_new = getlogprior(vi) - logjac + vi_new = setlogprior!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end @@ -578,8 +822,8 @@ function invlink!!( y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, x), lp_new) + lp_new = getlogprior(vi) + logjac + vi_new = setlogprior!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end @@ -723,9 +967,34 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) return x, logpdf(dist, x) + logjac end -# Legacy code that is currently overloaded for the sake of simplicity. -# TODO: Remove when possible. -increment_num_produce!(::AbstractVarInfo) = nothing +""" + get_num_produce(vi::AbstractVarInfo) + +Return the `num_produce` of `vi`. +""" +get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num + +""" + set_num_produce!!(vi::AbstractVarInfo, n::Int) + +Set the `num_produce` field of `vi` to `n`. +""" +set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n)) + +""" + increment_num_produce!!(vi::AbstractVarInfo) + +Add 1 to `num_produce` in `vi`. +""" +increment_num_produce!!(vi::AbstractVarInfo) = + map_accumulator!!(increment, vi, Val(:NumProduce)) + +""" + reset_num_produce!!(vi::AbstractVarInfo) + +Reset the value of `num_produce` in `vi` to 0. +""" +reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce)) """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl new file mode 100644 index 000000000..10a988ae5 --- /dev/null +++ b/src/accumulators.jl @@ -0,0 +1,189 @@ +""" + AbstractAccumulator + +An abstract type for accumulators. 
+ +An accumulator is an object that may change its value at every tilde_assume!! or +tilde_observe!! call based on the random variable in question. The obvious examples of +accumulators are the log prior and log likelihood. Other examples might be a variable that +counts the number of observations in a trace, or a list of the names of random variables +seen so far. + +An accumulator type `T <: AbstractAccumulator` must implement the following methods: +- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` +- `accumulate_observe!!(acc::T, right, left, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, right)` + +To be able to work with multi-threading, it should also implement: +- `split(acc::T)` +- `combine(acc::T, acc2::T)` + +See the documentation for each of these functions for more details. +""" +abstract type AbstractAccumulator end + +""" + accumulator_name(acc::AbstractAccumulator) + +Return a Symbol which can be used as a name for `acc`. + +The name has to be unique in the sense that a `VarInfo` can only have one accumulator for +each name. The most typical case, and the default implementation, is that the name only +depends on the type of `acc`, not on its value. +""" +accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) + +""" + accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) + +Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. + +`vn` is the name of the variable being observed, `left` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case +of literal observations like `0.0 ~ Normal()`. + +`accumulate_observe!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_assume!!`](@ref) +""" +function accumulate_observe!! end + +""" + accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right) + +Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. 
+
+`vn` is the name of the variable being assumed, `val` is the value of the variable, and
+`right` is the distribution on the RHS of the tilde statement. `logjac` is the log
+determinant of the Jacobian of the transformation that was done to convert the value of `vn`
+as it was given (e.g. by sampler operating in linked space) to `val`.
+
+`accumulate_assume!!` may mutate `acc`, but not any of the other arguments.
+
+See also: [`accumulate_observe!!`](@ref)
+"""
+function accumulate_assume!! end
+
+"""
+    split(acc::AbstractAccumulator)
+
+Return a new accumulator like `acc` but empty.
+
+The precise meaning of "empty" is that the returned value should be such that
+`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading
+where different threads may accumulate independently and the results are then combined.
+
+See also: [`combine`](@ref)
+"""
+function split end
+
+"""
+    combine(acc::AbstractAccumulator, acc2::AbstractAccumulator)
+
+Combine two accumulators of the same type. Returns a new accumulator.
+
+See also: [`split`](@ref)
+"""
+function combine end
+
+# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in
+# src/varinfo.jl.
+"""
+    convert_eltype(::Type{T}, acc::AbstractAccumulator)
+
+Convert `acc` to use element type `T`.
+
+What "element type" means depends on the type of `acc`. By default this function does
+nothing. Accumulator types that need to hold differentiable values, such as dual numbers
+used by various AD backends, should implement a method for this function.
+"""
+convert_eltype(::Type, acc::AbstractAccumulator) = acc
+
+"""
+    AccumulatorTuple{N,T<:NamedTuple}
+
+A collection of accumulators, stored as a `NamedTuple` of length `N`
+
+This is defined as a separate type to be able to dispatch on it cleanly and without method
+ambiguities or conflicts with other `NamedTuple` types. 
We also use this type to enforce the +constraint that the name in the tuple for each accumulator `acc` must be +`accumulator_name(acc)`, and these names must be unique. + +The constructor can be called with a tuple or a `VarArgs` of `AbstractAccumulators`. The +names will be generated automatically. One can also call the constructor with a `NamedTuple` +but the names in the argument will be discarded in favour of the generated ones. +""" +struct AccumulatorTuple{N,T<:NamedTuple} + nt::T + + function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}} + names = map(accumulator_name, t) + nt = NamedTuple{names}(t) + return new{N,typeof(nt)}(nt) + end +end + +AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) +AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) + +# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. +Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) +Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] +Base.length(::AccumulatorTuple{N}) where {N} = N +Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) +function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} + # @inline to ensure constant propagation can resolve this to a compile-time constant. + @inline return haskey(at.nt, accname) +end +Base.keys(at::AccumulatorTuple) = keys(at.nt) + +function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} + return AccumulatorTuple(convert(T, accs.nt)) +end + +""" + setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + +Add `acc` to `at`. Returns a new `AccumulatorTuple`. + +If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is +replaced. `at` will never be mutated, but the name has the `!!` for consistency with the +corresponding function for `AbstractVarInfo`. 
+""" +function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + accname = accumulator_name(acc) + new_nt = merge(at.nt, NamedTuple{(accname,)}((acc,))) + return AccumulatorTuple(new_nt) +end + +""" + getacc(at::AccumulatorTuple, ::Val{accname}) + +Get the accumulator with name `accname` from `at`. +""" +function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} + return at[accname] +end + +function Base.map(func::Function, at::AccumulatorTuple) + return AccumulatorTuple(map(func, at.nt)) +end + +""" + map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname}) + +Update the accumulator with name `accname` in `at` by calling `func` on it. + +Returns a new `AccumulatorTuple`. +""" +function map_accumulator( + func::Function, at::AccumulatorTuple, ::Val{accname} +) where {accname} + # Would like to write this as + # return Accessors.@set at.nt[accname] = func(at[accname], args...) + # for readability, but that one isn't type stable due to + # https://github.com/JuliaObjects/Accessors.jl/issues/198 + new_val = func(at[accname]) + new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,))) + return AccumulatorTuple(new_nt) +end diff --git a/src/compiler.jl b/src/compiler.jl index 6f7489b8e..9eb4835d3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -418,7 +418,7 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__ ) $value end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index eb025dec8..b92e49fba 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,27 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -# Allows samplers, etc. 
to hook into the final logp accumulation in the tilde-pipeline. -function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp) -end -function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(childcontext(context), vi, logp) -end -function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - -function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp) -end -function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(childcontext(context), vi, logp) -end -function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -52,36 +31,18 @@ function tilde_assume(context::SamplingContext, right, vn, vi) return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end -# Leaf contexts function tilde_assume(context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), context, args...) + return tilde_assume(childcontext(context), args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) +function tilde_assume(::DefaultContext, right, vn, vi) return assume(right, vn, vi) end -function tilde_assume(::IsParent, context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) 
-end -function tilde_assume( - ::IsLeaf, rng::Random.AbstractRNG, context::AbstractContext, sampler, right, vn, vi -) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume( - ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args... -) return tilde_assume(rng, childcontext(context), args...) end - -function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(nodist(right), vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, nodist(right), vn, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PrefixContext, right, vn, vi) @@ -137,55 +98,42 @@ function tilde_assume!!(context, right, vn, vi) end rand_like!!(right, context, vi) else - value, logp, vi = tilde_assume(context, right, vn, vi) - value, acclogp_assume!!(context, vi, logp) + value, vi = tilde_assume(context, right, vn, vi) + return value, vi end end # observe """ - tilde_observe(context::SamplingContext, right, left, vi) + tilde_observe!!(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +Falls back to `tilde_observe!!(context.context, right, left, vi)`. """ -function tilde_observe(context::SamplingContext, right, left, vi) - return tilde_observe(context.context, context.sampler, right, left, vi) +function tilde_observe!!(context::SamplingContext, right, left, vn, vi) + return tilde_observe!!(context.context, right, left, vn, vi) end -# Leaf contexts -function tilde_observe(context::AbstractContext, args...) - return tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) 
-function tilde_observe(::IsParent, context::AbstractContext, args...) - return tilde_observe(childcontext(context), args...) -end - -tilde_observe(::PriorContext, right, left, vi) = 0, vi -tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - logp, vi = tilde_observe(context.context, sampler, right, left, vi) - return context.loglike_scalar * logp, vi +function tilde_observe!!(context::AbstractContext, right, left, vn, vi) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end -function tilde_observe(context::PrefixContext, sampler, right, left, vi) - return tilde_observe(context.context, sampler, right, left, vi) +function tilde_observe!!(context::PrefixContext, right, left, vn, vi) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. For the need for prefix_and_strip_contexts rather than just prefix, see the + # comment in `tilde_assume!!`. + new_vn, new_context = if vn !== nothing + prefix_and_strip_contexts(context, vn) + else + vn, childcontext(context) + end + return tilde_observe!!(new_context, right, left, new_vn, vi) end """ - tilde_observe!!(context, right, left, vname, vi) + tilde_observe!!(context, right, left, vn, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value and updated `vi`. @@ -193,46 +141,27 @@ accumulate the log probability, and return the observed value and updated `vi`. 
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(context, right, left, vname, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return tilde_observe!!(context, right, left, vi) -end - -""" - tilde_observe(context, right, left, vi) - -Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and -return the observed value. - -By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_observe!!(context, right, left, vi) +function tilde_observe!!(context::DefaultContext, right, left, vn, vi) is_rhs_model(right) && throw( ArgumentError( "`~` with a model on the right-hand side of an observe statement is not supported", ), ) - logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi end function assume(rng::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end -function observe(spl::Sampler, weight) - return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") -end - # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r, logp = invlink_with_logpdf(vi, vn, dist) - return r, logp, vi + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, dist) + x, logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + return x, vi end # TODO: Remove this thing. 
@@ -254,8 +183,7 @@ function assume( r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! - # Also, if we use !! we shouldn't ignore the return value. - BangBang.setindex!!(vi, f(r), vn) + vi = BangBang.setindex!!(vi, f(r), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -265,22 +193,16 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist) + vi = push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. - settrans!!(vi, true, vn) + vi = settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) end end # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - return r, logpdf(dist, r) - logjac, vi -end - -# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) -observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) -function observe(right::Distribution, left, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(right, left), vi + vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + return r, vi end diff --git a/src/contexts.jl b/src/contexts.jl index 8ac085663..addadfa1a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -45,15 +45,17 @@ effectively updating the child context. 
# Examples ```jldoctest +julia> using DynamicPPL: DynamicTransformationContext + julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); julia> DynamicPPL.childcontext(ctx_prior) -PriorContext() +DynamicTransformationContext{true}() ``` """ setchildcontext @@ -78,7 +80,7 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext julia> struct ParentContext{C} <: AbstractContext context::C @@ -96,8 +98,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext() + leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) +DynamicTransformationContext{true}() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -129,7 +131,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. 
-See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R @@ -189,52 +191,11 @@ getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute the log joint probability of the data -and parameters when running the model. +The `DefaultContext` is used by default to accumulate values like the log joint probability +when running the model. """ struct DefaultContext <: AbstractContext end -NodeTrait(context::DefaultContext) = IsLeaf() - -""" - PriorContext <: AbstractContext - -A leaf context resulting in the exclusion of likelihood terms when running the model. -""" -struct PriorContext <: AbstractContext end -NodeTrait(context::PriorContext) = IsLeaf() - -""" - LikelihoodContext <: AbstractContext - -A leaf context resulting in the exclusion of prior terms when running the model. -""" -struct LikelihoodContext <: AbstractContext end -NodeTrait(context::LikelihoodContext) = IsLeaf() - -""" - struct MiniBatchContext{Tctx, T} <: AbstractContext - context::Tctx - loglike_scalar::T - end - -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. 
-""" -struct MiniBatchContext{Tctx,T} <: AbstractContext - context::Tctx - loglike_scalar::T -end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) - return MiniBatchContext(context, npoints / batch_size) -end -NodeTrait(context::MiniBatchContext) = IsParent() -childcontext(context::MiniBatchContext) = context.context -function setchildcontext(parent::MiniBatchContext, child) - return MiniBatchContext(child, parent.loglike_scalar) -end +NodeTrait(::DefaultContext) = IsLeaf() """ PrefixContext(vn::VarName[, context::AbstractContext]) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 15ef8fb01..238cd422d 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - logp varinfo = nothing end @@ -89,16 +88,12 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return print(io, stmt.value) end Base.@kwdef struct ObserveStmt <: Stmt left right - logp varinfo = nothing end @@ -107,10 +102,7 @@ function Base.show(io::IO, stmt::ObserveStmt) print(io, "observe: ") show_right(io, stmt.left) print(io, " ~ ") - show_right(io, stmt.right) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return show_right(io, stmt.right) end # Some utility methods for extracting information from a trace. @@ -252,12 +244,11 @@ function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) return nothing end -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, logp, varinfo) +function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) stmt = AssumeStmt(; varname=vn, right=dist, value=value, - logp=logp, varinfo=context.record_varinfo ? 
varinfo : nothing, ) if context.record_statements @@ -268,19 +259,17 @@ end function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end function DynamicPPL.tilde_assume( rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi ) record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume( - rng, childcontext(context), sampler, right, vn, vi - ) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi + value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + record_post_tilde_assume!(context, vn, right, value, vi) + return value, vi end # observe @@ -304,12 +293,9 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) end end -function record_post_tilde_observe!(context::DebugContext, left, right, logp, varinfo) +function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) stmt = ObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? varinfo : nothing, + left=left, right=right, varinfo=context.record_varinfo ? 
varinfo : nothing ) if context.record_statements push!(context.statements, stmt) @@ -317,17 +303,17 @@ function record_post_tilde_observe!(context::DebugContext, left, right, logp, va return nothing end -function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end -function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) + record_post_tilde_observe!(context, left, right, vi) + return vi end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -413,7 +399,7 @@ julia> issuccess true julia> print(trace) - assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356) + assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); @@ -421,7 +407,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894) +observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl new file mode 100644 index 000000000..ab538ba51 --- /dev/null +++ b/src/default_accumulators.jl @@ -0,0 +1,154 @@ +""" + 
LogPriorAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log prior during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator + "the scalar log prior value" + logp::T +end + +""" + LogPriorAccumulator{T}() + +Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. +""" +LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) +LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() + +""" + LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log likelihood during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + "the scalar log likelihood value" + logp::T +end + +""" + LogLikelihoodAccumulator{T}() + +Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. +""" +LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) +LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() + +""" + NumProduceAccumulator{T} <: AbstractAccumulator + +An accumulator that tracks the number of observations during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator + "the number of observations" + num::T +end + +""" + NumProduceAccumulator{T<:Integer}() + +Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. 
+""" +NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T)) +NumProduceAccumulator() = NumProduceAccumulator{Int}() + +function Base.show(io::IO, acc::LogPriorAccumulator) + return print(io, "LogPriorAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::LogLikelihoodAccumulator) + return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") +end +function Base.show(io::IO, acc::NumProduceAccumulator) + return print(io, "NumProduceAccumulator($(repr(acc.num)))") +end + +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood +accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce + +split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) +split(acc::NumProduceAccumulator) = acc + +function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc.logp + acc2.logp) +end +function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc.logp + acc2.logp) +end +function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator) + return NumProduceAccumulator(max(acc.num, acc2.num)) +end + +function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) + return LogPriorAccumulator(acc1.logp + acc2.logp) +end +function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return LogLikelihoodAccumulator(acc1.logp + acc2.logp) +end +increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num)) + +Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) +Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num)) + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, 
vn, right) + return acc + LogPriorAccumulator(logpdf(right, val) + logjac) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) +end + +accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc +accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc) + +function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator +) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end +function Base.convert( + ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator +) where {T} + return NumProduceAccumulator(convert(T, acc.num)) +end + +# TODO(mhauru) +# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on +# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to +# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is +# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. 
+function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} + return LogPriorAccumulator(convert(T, acc.logp)) +end +function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} + return LogLikelihoodAccumulator(convert(T, acc.logp)) +end + +function default_accumulators( + ::Type{FloatT}=LogProbType, ::Type{IntT}=Int +) where {FloatT,IntT} + return AccumulatorTuple( + LogPriorAccumulator{FloatT}(), + LogLikelihoodAccumulator{FloatT}(), + NumProduceAccumulator{IntT}(), + ) +end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a42855f05..1b5e9b8c4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -51,7 +51,7 @@ $(FIELDS) ```jldoctest julia> using Distributions -julia> using DynamicPPL: LogDensityFunction, contextualize +julia> using DynamicPPL: LogDensityFunction, setaccs!! julia> @model function demo(x) m ~ Normal() @@ -78,8 +78,8 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary. julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # This also respects the context in `model`. - f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); +julia> # LogDensityFunction respects the accumulators in VarInfo: + f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -174,14 +174,26 @@ end Evaluate the log density of the given `model` at the given parameter values `x`, using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted into -it, and its own parameters are discarded. +only for its structure, in the sense that the parameters from the vector `x` are inserted +into it, and its own parameters are discarded. 
It does, however, determine whether the log +prior, likelihood, or joint is returned, based on which accumulators are set in it. """ function logdensity_at( x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) varinfo_new = unflatten(varinfo, x) - return getlogp(last(evaluate!!(model, varinfo_new, context))) + varinfo_eval = last(evaluate!!(model, varinfo_new, context)) + has_prior = hasacc(varinfo_eval, Val(:LogPrior)) + has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) + if has_prior && has_likelihood + return getlogjoint(varinfo_eval) + elseif has_prior + return getlogprior(varinfo_eval) + elseif has_likelihood + return getloglikelihood(varinfo_eval) + else + error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") + end end ### LogDensityProblems interface diff --git a/src/model.jl b/src/model.jl index c7c4bdf57..3b93fa14d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -900,7 +900,7 @@ See also: [`evaluate_threadunsafe!!`](@ref) function evaluate_threadsafe!!(model, varinfo, context) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) - return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) end """ @@ -1010,7 +1010,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1057,7 +1057,14 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). 
""" function logprior(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + logprioracc = if hasacc(varinfo, Val(:LogPrior)) + getacc(varinfo, Val(:LogPrior)) + else + LogPriorAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,)) + return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1104,7 +1111,14 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + loglikelihoodacc = if hasacc(varinfo, Val(:LogLikelihood)) + getacc(varinfo, Val(:LogLikelihood)) + else + LogLikelihoodAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,)) + return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -1358,7 +1372,7 @@ We can check that the log joint probability of the model accumulated in `vi` is ```jldoctest submodel-to_submodel julia> x = vi[@varname(a.x)]; -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` @@ -1422,7 +1436,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogp(vi) ≈ logprior + loglikelihood +julia> getlogjoint(vi) ≈ logprior + loglikelihood true ``` diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..b6b97c8f9 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,142 +1,117 @@ -# Context version -struct 
PointwiseLogdensityContext{A,Ctx} <: AbstractContext - logdensities::A - context::Ctx -end +""" + PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator -function PointwiseLogdensityContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DefaultContext(), -) - return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end +An accumulator that stores the log-probabilities of each variable in a model. -NodeTrait(::PointwiseLogdensityContext) = IsParent() -childcontext(context::PointwiseLogdensityContext) = context.context -function setchildcontext(context::PointwiseLogdensityContext, child) - return PointwiseLogdensityContext(context.logdensities, child) -end +Internally this context stores the log-probabilities in a dictionary, where the keys are +the variable names and the values are vectors of log-probabilities. Each element in a vector +corresponds to one execution of the model. -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies +which log-probabilities to store in the accumulator. `KeyType` is the type by which variable +names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary +used internally to store the log-probabilities, by default +`OrderedDict{KeyType, Vector{LogProbType}}`. 
+""" +struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: + AbstractAccumulator + logps::D end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[vn] = logp +function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) +function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,VarName}() end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[string(vn)] = logp +function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} + logps = OrderedDict{KeyType,Vector{LogProbType}}() + return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) + logps = acc.logps + # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
+ T = last(fieldtypes(eltype(logps))) + logpvec = get!(logps, vn, T()) + return push!(logpvec, logp) end function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logdensities[vn] = logp + acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp +) where {whichlogprob} + return push!(acc, string(vn), logp) end -function _include_prior(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{PriorContext,DefaultContext} -end -function _include_likelihood(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +function accumulator_name( + ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} +) where {whichlogprob} + return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) +function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return tilde_observe!!(context.context, right, left, vn, vi) - end - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - - # Track loglikelihood value. 
- push!(context, vn, logp) - - return left, acclogp!!(vi, logp) +function combine( + acc::PointwiseLogProbAccumulator{whichlogprob}, + acc2::PointwiseLogProbAccumulator{whichlogprob}, +) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) end -# Note on submodels (penelopeysm) -# -# We don't need to overload tilde_observe!! for Sampleables (yet), because it -# is currently not possible to evaluate a model with a Sampleable on the RHS -# of an observe statement. -# -# Note that calling tilde_assume!! on a Sampleable does not necessarily imply -# that there are no observe statements inside the Sampleable. There could well -# be likelihood terms in there, which must be included in the returned logp. -# See e.g. the `demo_dot_assume_observe_submodel` demo model. -# -# This is handled by passing the same context to rand_like!!, which figures out -# which terms to include using the context, and also mutates the context and vi -# appropriately. Thus, we don't need to check against _include_prior(context) -# here. -function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi) - value, vi = DynamicPPL.rand_like!!(right, context, vi) - return value, vi +function accumulate_assume!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right +) where {whichlogprob} + if whichlogprob == :both || whichlogprob == :prior + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) + push!(acc, vn, subacc.logp) + end + return acc end -function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) - !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) - value, logp, vi = tilde_assume(context.context, right, vn, vi) - # Track loglikelihood value. 
- push!(context, vn, logp) - return value, acclogp!!(vi, logp) +function accumulate_observe!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn +) where {whichlogprob} + # If `vn` is nothing the LHS of ~ is a literal and we don't have a name to attach this + # acc to, and thus do nothing. + if vn === nothing + return acc + end + if whichlogprob == :both || whichlogprob == :likelihood + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) + push!(acc, vn, subacc.logp) + end + return acc end """ - pointwise_logdensities(model::Model, chain::Chains, keytype = String) + pointwise_logdensities( + model::Model, + chain::Chains, + keytype=String, + context=DefaultContext(), + ::Val{whichlogprob}=Val(:both), + ) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. +Currently, only `String` and `VarName` are supported. `context` is the evaluation context, +and `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, +`:prior`, or `:likelihood`. + +See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). # Notes Say `y` is a `Vector` of `n` i.i.d. 
`Normal(μ, σ)` variables, with `μ` and `σ` @@ -234,14 +209,19 @@ julia> m = demo([1.0; 1.0]); julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` - """ function pointwise_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} + model::Model, + chain, + ::Type{KeyType}=String, + context::AbstractContext=DefaultContext(), + ::Val{whichlogprob}=Val(:both), +) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) - point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} + vi = setaccs!!(vi, (AccType(),)) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -249,26 +229,28 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, point_context) + vi = last(evaluate!!(model, vi, context)) end + logps = getacc(vi, Val(accumulator_name(AccType))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in point_context.logdensities + varname => reshape(vals, niters, nchains) for (varname, vals) in logps ) return logdensities end function pointwise_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context - ) - model(varinfo, point_context) - return point_context.logdensities + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext=DefaultContext(), + ::Val{whichlogprob}=Val(:both), +) where {whichlogprob} + AccType = PointwiseLogProbAccumulator{whichlogprob} + varinfo = setaccs!!(varinfo, (AccType(),)) + varinfo = last(evaluate!!(model, varinfo, 
context)) + return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ @@ -277,29 +259,19 @@ end Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). """ function pointwise_loglikelihoods( - model::Model, - chain, - keytype::Type{T}=String, - context::AbstractContext=LikelihoodContext(), + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, context, Val(:likelihood)) end function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() ) - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, varinfo, context) + return pointwise_logdensities(model, varinfo, context, Val(:likelihood)) end """ @@ -308,24 +280,17 @@ end Compute the pointwise log-prior-densities of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the prior terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
""" function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, context, Val(:prior)) end function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() ) - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, varinfo, context) + return pointwise_logdensities(model, varinfo, context, Val(:prior)) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abf14b8fc..42fcedfb8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -125,18 +125,18 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), 0.0) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) Positive probability mass on negative numbers! 
- getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), 0.0) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) No probability mass on negative numbers! - getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) -Inf ``` @@ -188,41 +188,37 @@ ERROR: type NamedTuple has no field b [...] ``` """ -struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo +struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: + AbstractVarInfo "underlying representation of the realization represented" values::NT - "holds the accumulated log-probability" - logp::T + "tuple of accumulators for things like log prior and log likelihood" + accs::Accs "represents whether it assumes variables to be transformed" transformation::C end transformation(vi::SimpleVarInfo) = vi.transformation -# Makes things a bit more readable vs. putting `Float64` everywhere. -const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64 - -function SimpleVarInfo{NT,T}(values, logp) where {NT,T} - return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation()) +function SimpleVarInfo(values, accs) + return SimpleVarInfo(values, accs, NoTransformation()) end -function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T)) +function SimpleVarInfo{T}(values) where {T<:Real} + return SimpleVarInfo(values, default_accumulators(T)) end - -# Constructors without type-specification. 
-SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) -function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict}) - return if isempty(θ) +function SimpleVarInfo(values) + return SimpleVarInfo{LogProbType}(values) +end +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict}) + return if isempty(values) # Can't infer from values, so we just use default. - SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) + SimpleVarInfo{LogProbType}(values) else # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ) + SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) end end -SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp) - # Using `kwargs` to specify the values. function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -235,7 +231,7 @@ end function SimpleVarInfo( model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... ) - return SimpleVarInfo{Float64}(model, args...) + return SimpleVarInfo{LogProbType}(model, args...) end function SimpleVarInfo{T}( model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... @@ -244,14 +240,14 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} - return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) 
+function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D}
+    values = values_as(vi, D)
+    return SimpleVarInfo(values, deepcopy(getaccs(vi)))
 end
-function SimpleVarInfo{T}(
-    vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
-) where {T<:Real,names,D}
+function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
     values = values_as(vi, D)
-    return SimpleVarInfo(values, convert(T, getlogp(vi)))
+    accs = map(acc -> convert_eltype(T, acc), getaccs(vi))
+    return SimpleVarInfo(values, accs)
 end
 
 function untyped_simple_varinfo(model::Model)
@@ -265,12 +261,16 @@ function typed_simple_varinfo(model::Model)
 end
 
 function unflatten(svi::SimpleVarInfo, x::AbstractVector)
-    logp = getlogp(svi)
     vals = unflatten(svi.values, x)
-    T = eltype(x)
-    return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}(
-        vals, T(logp), svi.transformation
+    # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
+    # required but undesirable.
+    # The below line is finicky for type stability. For instance, assigning the eltype to
+    # convert to into an intermediate variable makes this unstable (constant propagation
+    # fails). Take care when editing.
+ accs = map( + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) ) + return SimpleVarInfo(vals, accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) @@ -278,21 +278,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) -getlogp(vi::SimpleVarInfo) = vi.logp -getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] - -setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp - -function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::SimpleVarInfo) = vi.accs +setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ keys(vi::SimpleVarInfo) @@ -302,12 +289,12 @@ Return an iterator of keys present in `vi`. Base.keys(vi::SimpleVarInfo) = keys(vi.values) Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) +function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") end function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) @@ -454,11 +441,11 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - logp = getlogp(varinfo_right) + accs = deepcopy(getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) - return SimpleVarInfo(values, logp, 
transformation) + return SimpleVarInfo(values, accs, transformation) end # Context implementations @@ -473,9 +460,11 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = to_maybe_linked_internal(vi, vn, dist, value) + f = to_maybe_linked_internal_transform(vi, vn, dist) + value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi + vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + return value, vi end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. @@ -497,8 +486,8 @@ islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - isempty(vi) && return T[] +function values_as(vi::SimpleVarInfo, ::Type{Vector}) + isempty(vi) && return Any[] return mapreduce(tovec, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} @@ -613,12 +602,11 @@ function link!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) x = vi.values y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) + vi_new = Accessors.@set(vi.values = y) + vi_new = acclogprior!!(vi_new, -logjac) return settrans!!(vi_new, t) end @@ -627,12 +615,11 @@ function invlink!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. 
b = t.bijector y = vi.values x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new) + vi_new = Accessors.@set(vi.values = x) + vi_new = acclogprior!!(vi_new, logjac) return settrans!!(vi_new, NoTransformation()) end @@ -645,13 +632,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) return invlink_transform(dist) end -# Threadsafe stuff. -# For `SimpleVarInfo` we don't really need `Ref` so let's not use it. -function ThreadSafeVarInfo(vi::SimpleVarInfo) - return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads())) -end -function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) -end - has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 5f1ec95ec..bd08b427e 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -45,7 +45,7 @@ We can check that the log joint probability of the model accumulated in `vi` is ```jldoctest submodel julia> x = vi[@varname(x)]; -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` """ @@ -124,7 +124,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogp(vi) ≈ logprior + loglikelihood +julia> getlogjoint(vi) ≈ logprior + loglikelihood true ``` diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 7404a9af7..08acdfada 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -3,34 +3,6 @@ # # Utilities for testing contexts. -""" -Context that multiplies each log-prior by mod -used to test whether varwise_logpriors respects child-context. 
-""" -struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext - mod::T - context::Ctx -end -function TestLogModifyingChildContext( - mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() -) - return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) -end - -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context -function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) - return TestLogModifyingChildContext(context.mod, child) -end -function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, logp * context.mod, vi -end -function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end - # Dummy context to test nested behaviors. struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext context::C @@ -61,7 +33,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # To see change, let's make sure we're using a different leaf context than the current. 
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - PriorContext() + DynamicPPL.DynamicTransformationContext{false}() else DefaultContext() end diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index e29614982..12f88acad 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -148,7 +148,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s, m, x=[1.5, 2.0]) end function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) @@ -194,7 +194,7 @@ end m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -225,7 +225,7 @@ end end x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -248,7 +248,7 @@ end m ~ MvNormal(zero(x), Diagonal(s)) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -279,7 +279,7 @@ end x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -304,7 +304,7 @@ end m ~ Normal(0, sqrt(s)) x 
.~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -327,7 +327,7 @@ end m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -358,7 +358,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -384,7 +384,7 @@ end 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -407,7 +407,7 @@ end m ~ Normal(0, sqrt(s)) [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -440,7 +440,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -476,9 +476,9 @@ end # Submodel likelihood # With to_submodel, we 
have to have a left-hand side variable to # capture the result, so we just use a dummy variable - _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) + _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x), false) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -505,7 +505,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -535,7 +535,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 539872143..07a308c7a 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -37,12 +37,6 @@ function setup_varinfos( svi_untyped = SimpleVarInfo(OrderedDict()) svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - # SimpleVarInfo{<:Any,<:Ref} - svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) - svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) - svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv))) - - lp = getlogp(vi_typed_metadata) varinfos = map(( vi_untyped_metadata, vi_untyped_vnv, @@ -51,12 +45,10 @@ function setup_varinfos( svi_typed, svi_untyped, svi_vnv, - svi_typed_ref, - svi_untyped_ref, - svi_vnv_ref, )) do vi - # Set them all to the same values. 
- DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) + # Set them all to the same values and evaluate logp. + vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) end if include_threadsafe diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2dc2645de..7d2d768a6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -2,69 +2,79 @@ ThreadSafeVarInfo A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of log probabilities for thread-safe execution of a probabilistic model. +array of accumulators for thread-safe execution of a probabilistic model. """ -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo +struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo varinfo::V - logps::L + accs_by_thread::Vector{L} end function ThreadSafeVarInfo(vi::AbstractVarInfo) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.nthreads()] + return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi -const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{ - V,<:AbstractArray{<:Ref} -} - transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) -# Instead of updating the log probability of the underlying variables we -# just update the array of log probabilities. -function acclogp!!(vi::ThreadSafeVarInfo, logp) - vi.logps[Threads.threadid()] += logp - return vi +# Set the accumulator in question in vi.varinfo, and set the thread-specific +# accumulators of the same type to be empty. 
+function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) + inner_vi = setacc!!(vi.varinfo, acc) + news_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) + return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) end -function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp) - vi.logps[Threads.threadid()][] += logp - return vi + +# Get both the main accumulator and the thread-specific accumulators of the same type and +# combine them. +function getacc(vi::ThreadSafeVarInfo, accname::Val) + main_acc = getacc(vi.varinfo, accname) + other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) + return foldl(combine, other_accs; init=main_acc) end -# The current log probability of the variables has to be computed from -# both the wrapped variables and the thread-specific log probabilities. -getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps) -getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps) +hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) +acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) -# TODO: Make remaining methods thread-safe. -function resetlogp!!(vi::ThreadSafeVarInfo) - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps)) +function getaccs(vi::ThreadSafeVarInfo) + # This method is a bit finicky to maintain type stability. For instance, moving the + # accname -> Val(accname) part in the main `map` call makes constant propagation fail + # and this becomes unstable. Do check the effects if you make edits. 
+ accnames = acckeys(vi) + accname_vals = map(Val, accnames) + return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) end -function resetlogp!!(vi::ThreadSafeVarInfoWithRef) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps) -end -function setlogp!!(vi::ThreadSafeVarInfo, logp) - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps)) + +# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that +# should _not_ be thread-specific a specific method has to be written. +function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) + tid = Threads.threadid() + vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) + return vi end -function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) + +function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) + tid = Threads.threadid() + vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) + return vi end -has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) +has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end +# TODO(mhauru) Why these short-circuits? Why not use the thread-specific ones? 
get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) -increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) -reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) -set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n) +function increment_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(increment_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function reset_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(reset_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function set_num_produce!!(vi::ThreadSafeVarInfo, n::Int) + return ThreadSafeVarInfo(set_num_produce!!(vi.varinfo, n), vi.accs_by_thread) +end syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) @@ -94,8 +104,8 @@ end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates -# to define `getlogp(vi)`. +# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates +# to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -130,9 +140,9 @@ end function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the - # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogp(vi)`. + # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the + # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogprior(vi)`. 
return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end @@ -169,6 +179,23 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end +function resetlogp!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) + for i in eachindex(vi.accs_by_thread) + if hasacc(vi, Val(:LogPrior)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogPrior) + ) + end + if hasacc(vi, Val(:LogLikelihood)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogLikelihood) + ) + end + end + return vi +end + values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) diff --git a/src/transforming.jl b/src/transforming.jl index 429562ec8..ddd1ab59f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -27,18 +27,47 @@ function tilde_assume( # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. r_transformed = isinverse ? 
r : link_transform(right)(r) - return r, lp, setindex!!(vi, r_transformed, vn) + if hasacc(vi, Val(:LogPrior)) + vi = acclogprior!!(vi, lp) + end + return r, setindex!!(vi, r_transformed, vn) +end + +function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + return _transform!!(t, DynamicTransformationContext{false}(), vi, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), - ) + return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) +end + +function _transform!!( + t::AbstractTransformation, + ctx::DynamicTransformationContext, + vi::AbstractVarInfo, + model::Model, +) + # To transform using DynamicTransformationContext, we evaluate the model, but we do not + # need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of + # the transformation). + accs = getaccs(vi) + has_logprior = haskey(accs, Val(:LogPrior)) + if has_logprior + old_logprior = getacc(accs, Val(:LogPrior)) + vi = setaccs!!(vi, (old_logprior,)) + end + vi = settrans!!(last(evaluate!!(model, vi, ctx)), t) + # Restore the accumulators. + if has_logprior + new_logprior = getacc(vi, Val(:LogPrior)) + accs = setacc!!(accs, new_logprior) + end + vi = setaccs!!(vi, accs) + return vi end function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/utils.jl b/src/utils.jl index 71919480c..9a9f39ede 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,23 +18,29 @@ const LogProbType = float(Real) """ @addlogprob!(ex) -Add the result of the evaluation of `ex` to the joint log probability. 
+Add a term to the log joint. -# Examples +If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the +values are added to the log likelihood and log prior respectively. + +If `ex` evaluates to a number it is added to the log likelihood. -This macro allows you to [include arbitrary terms in the likelihood](https://github.com/TuringLang/Turing.jl/issues/1332) +# Examples ```jldoctest; setup = :(using Distributions) -julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); +julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0); julia> @model function demo(x) μ ~ Normal() - @addlogprob! myloglikelihood(x, μ) + @addlogprob! mylogjoint(x, μ) end; julia> x = [1.3, -2.1]; -julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2) +julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood +true + +julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior true ``` @@ -44,7 +50,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): julia> @model function demo(x) m ~ MvNormal(zero(x), I) if dot(m, x) < 0 - @addlogprob! -Inf + @addlogprob! (; loglikelihood=-Inf) # Exit the model evaluation early return end @@ -55,37 +61,22 @@ julia> @model function demo(x) julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf true ``` - -!!! note - The `@addlogprob!` macro increases the accumulated log probability regardless of the evaluation context, - i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. - If you would like to avoid this behaviour you should check the evaluation context. - It can be accessed with the internal variable `__context__`. 
- For instance, in the following example the log density is not accumulated when only the log prior is computed: - ```jldoctest; setup = :(using Distributions) - julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); - - julia> @model function demo(x) - μ ~ Normal() - if DynamicPPL.leafcontext(__context__) !== PriorContext() - @addlogprob! myloglikelihood(x, μ) - end - end; - - julia> x = [1.3, -2.1]; - - julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) - true - - julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2) - true - ``` """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!( - $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) - ) + val = $(esc(ex)) + vi = $(esc(:(__varinfo__))) + if val isa Number + if hasacc(vi, Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val) + end + elseif val isa NamedTuple + $(esc(:(__varinfo__))) = acclogp!!( + $(esc(:(__varinfo__))), val; ignore_missing_accumulator=true + ) + else + error("logp must be a Number or a NamedTuple.") + end end end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..3ec474940 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -65,29 +65,24 @@ end function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) if is_tracked_value(right) value = right.value - logp = zero(getlogp(vi)) else - value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + value, vi = tilde_assume(childcontext(context), right, vn, vi) end - # Save the value. push!(context, vn, value) - # Save the value. - # Pass on. 
- return value, logp, vi + return value, vi end function tilde_assume( rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi ) if is_tracked_value(right) value = right.value - logp = zero(getlogp(vi)) else - value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) end # Save the value. push!(context, vn, value) # Pass on. - return value, logp, vi + return value, vi end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 360857ef7..6a968da4d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,10 +69,9 @@ end ########### """ - struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end A light wrapper over some kind of metadata. @@ -98,12 +97,14 @@ Note that for NTVarInfo, it is the user's responsibility to ensure that each symbol is visited at least once during model evaluation, regardless of any stochastic branching. """ -struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +function VarInfo(meta=Metadata()) + return VarInfo(meta, default_accumulators()) +end + """ VarInfo([rng, ]model[, sampler, context]) @@ -285,10 +286,8 @@ function typed_varinfo(vi::UntypedVarInfo) ), ) end - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -349,8 +348,7 @@ single `VarNamedVector` as its metadata field. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( rng::Random.AbstractRNG, @@ -393,15 +391,12 @@ NamedTuple of `VarNamedVector`s as its metadata field. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, deepcopy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( rng::Random.AbstractRNG, @@ -441,13 +436,22 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases - # where e.g. x is a type gradient of some AD backend. - return VarInfo( - md, - Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), - Ref(get_num_produce(vi)), + # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is + # a gradient type of some AD backend. + # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! + # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but + # the accumulators in the VarInfo are plain floats, we error since we can't change the + # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here + # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just + # plain ugly and hacky. 
+ # The below line is finicky for type stability. For instance, assigning the eltype to + # convert to into an intermediate variable makes this unstable (constant propagation) + # fails. Take care when editing. + accs = map( + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), + deepcopy(getaccs(vi)), ) + return VarInfo(md, accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in @@ -529,7 +533,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) + return VarInfo(metadata, deepcopy(varinfo.accs)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -618,9 +622,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo( - metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right)) - ) + return VarInfo(metadata, deepcopy(varinfo_right.accs)) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -973,8 +975,8 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!!(vi) - reset_num_produce!(vi) + vi = resetlogp!!(vi) + vi = reset_num_produce!!(vi) return vi end @@ -1008,46 +1010,8 @@ end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") -getlogp(vi::VarInfo) = vi.logp[] - -function setlogp!!(vi::VarInfo, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::VarInfo, logp) - vi.logp[] += logp - return vi -end - -""" - get_num_produce(vi::VarInfo) - -Return the `num_produce` of `vi`. -""" -get_num_produce(vi::VarInfo) = vi.num_produce[] - -""" - set_num_produce!(vi::VarInfo, n::Int) - -Set the `num_produce` field of `vi` to `n`. 
-""" -set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n - -""" - increment_num_produce!(vi::VarInfo) - -Add 1 to `num_produce` in `vi`. -""" -increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 - -""" - reset_num_produce!(vi::VarInfo) - -Reset the value of `num_produce` the log of the joint probability of the observed data -and parameters sampled in `vi` to 0. -""" -reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) @@ -1061,7 +1025,7 @@ function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1069,7 +1033,7 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1082,8 +1046,7 @@ end function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - # Call `_link!` instead of `link!` to avoid deprecation warning. 
- _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1098,27 +1061,28 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, vns) +function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, true, vn) end + return vi else @warn("[DynamicPPL] attempt to link a linked vi") end end -# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a # NamedTuple that matches the structure of the NTVarInfo. -function _link!(vi::NTVarInfo, vns::VarNameTuple) - return _link!(vi, group_varnames_by_symbol(vns)) +function _link!!(vi::NTVarInfo, vns::VarNameTuple) + return _link!!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::NTVarInfo, vns::NamedTuple) - return _link!(vi.metadata, vi, vns) +function _link!!(vi::NTVarInfo, vns::NamedTuple) + return _link!!(vi.metadata, vi, vns) end """ @@ -1130,7 +1094,7 @@ function filter_subsumed(filter_vns, filtered_vns) return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end -@generated function _link!( +@generated function _link!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1148,8 +1112,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -1158,6 +1122,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1165,8 +1130,7 @@ function 
invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1174,7 +1138,7 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1187,8 +1151,7 @@ end function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1211,29 +1174,30 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!(vi::UntypedVarInfo, vns) +function _invlink!!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end + return vi else @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a # NamedTuple that matches the structure of the NTVarInfo. 
-function _invlink!(vi::NTVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) +function _invlink!!(vi::NTVarInfo, vns::VarNameTuple) + return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::NTVarInfo, vns::NamedTuple) - return _invlink!(vi.metadata, vi, vns) +function _invlink!!(vi::NTVarInfo, vns::NamedTuple) + return _invlink!!(vi.metadata, vi, vns) end -@generated function _invlink!( +@generated function _invlink!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1251,8 +1215,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1260,6 +1224,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1276,7 +1241,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
setval!(md, yvec, vn) - acclogp!!(vi, -logjac) + vi = acclogprior!!(vi, -logjac) return vi end @@ -1311,8 +1276,10 @@ end function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1323,8 +1290,10 @@ end function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _link_metadata!( @@ -1333,20 +1302,39 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if f in vns_names - push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + NamedTuple{$metadata_names}($mds), cumulative_logjac + end, + ) + return expr end function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = 
zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1364,7 +1352,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Vectorize value. yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. @@ -1389,7 +1377,8 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _link_metadata!!( @@ -1397,6 +1386,7 @@ function _link_metadata!!( ) vns = target_vns === nothing ? keys(metadata) : target_vns dists = extract_priors(model, varinfo) + cumulative_logjac = zero(LogProbType) for vn in vns # First transform from however the variable is stored in vnv to the model # representation. @@ -1409,11 +1399,11 @@ function _link_metadata!!( val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) # TODO(mhauru) We are calling a !! function but ignoring the return value. # Fix this when attending to issue #653. 
- acclogp!!(varinfo, -logjac1 - logjac2) + cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) settrans!(metadata, true, vn) end - return metadata + return metadata, cumulative_logjac end function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) @@ -1449,11 +1439,10 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end # If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1464,8 +1453,10 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + return new_varinfo end @generated function _invlink_metadata!( @@ -1474,20 +1465,41 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if (f in vns_names) - push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _invlink_metadata!!( + model, varinfo, metadata.$f, vns.$f + ) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return 
:(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + (NamedTuple{$metadata_names}($mds), cumulative_logjac) + end, + ) + return expr end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1506,7 +1518,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1531,24 +1543,26 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns + cumulative_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) new_val, logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. 
- acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata + return metadata, cumulative_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not @@ -1705,19 +1719,35 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - vi_str = """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - | Varnames : $(string(vi.metadata.vns)) - | Range : $(vi.metadata.ranges) - | Vals : $(vi.metadata.vals) - | Orders : $(vi.metadata.orders) - | Logp : $(getlogp(vi)) - | #produce : $(get_num_produce(vi)) - | flags : $(vi.metadata.flags) - \\======================================================================= - """ + lines = Tuple{String,Any}[ + ("VarNames", vi.metadata.vns), + ("Range", vi.metadata.ranges), + ("Vals", vi.metadata.vals), + ("Orders", vi.metadata.orders), + ] + for accname in acckeys(vi) + push!(lines, (string(accname), getacc(vi, Val(accname)))) + end + push!(lines, ("flags", vi.metadata.flags)) + max_name_length = maximum(map(length ∘ first, lines)) + fmt = Printf.Format("%-$(max_name_length)s") + vi_str = ( + """ + /======================================================================= + | VarInfo + |----------------------------------------------------------------------- + """ * + prod( + map(lines) do (name, value) + """ + | $(Printf.format(fmt, name)) : $(value) + """ + end, + ) * + """ + \\======================================================================= + """ + ) return print(io, vi_str) end @@ -1747,7 +1777,11 @@ end function Base.show(io::IO, vi::UntypedVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi); 
digits=3)) + print(io, "; accumulators: ") + # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed repretation + # of vi anyway. However, technically `show(io, x)` should give full details of x and + # preferably output valid Julia code. + show(io, MIME"text/plain"(), getaccs(vi)) return print(io, ")") end diff --git a/test/accumulators.jl b/test/accumulators.jl new file mode 100644 index 000000000..36bb95e46 --- /dev/null +++ b/test/accumulators.jl @@ -0,0 +1,176 @@ +module AccumulatorTests + +using Test +using Distributions +using DynamicPPL +using DynamicPPL: + AccumulatorTuple, + LogLikelihoodAccumulator, + LogPriorAccumulator, + NumProduceAccumulator, + accumulate_assume!!, + accumulate_observe!!, + combine, + convert_eltype, + getacc, + increment, + map_accumulator, + setacc!!, + split + +@testset "accumulators" begin + @testset "individual accumulator types" begin + @testset "constructors" begin + @test LogPriorAccumulator(0.0) == + LogPriorAccumulator() == + LogPriorAccumulator{Float64}() == + LogPriorAccumulator{Float64}(0.0) == + zero(LogPriorAccumulator(1.0)) + @test LogLikelihoodAccumulator(0.0) == + LogLikelihoodAccumulator() == + LogLikelihoodAccumulator{Float64}() == + LogLikelihoodAccumulator{Float64}(0.0) == + zero(LogLikelihoodAccumulator(1.0)) + @test NumProduceAccumulator(0) == + NumProduceAccumulator() == + NumProduceAccumulator{Int}() == + NumProduceAccumulator{Int}(0) == + zero(NumProduceAccumulator(1)) + end + + @testset "addition and incrementation" begin + @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) == + LogPriorAccumulator(2.0f0) + @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) == + LogPriorAccumulator(2.0) + @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) == + LogLikelihoodAccumulator(2.0f0) + @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == + LogLikelihoodAccumulator(2.0) + @test increment(NumProduceAccumulator()) == 
NumProduceAccumulator(1) + @test increment(NumProduceAccumulator{UInt8}()) == + NumProduceAccumulator{UInt8}(1) + end + + @testset "split and combine" begin + for acc in [ + LogPriorAccumulator(1.0), + LogLikelihoodAccumulator(1.0), + NumProduceAccumulator(1), + LogPriorAccumulator(1.0f0), + LogLikelihoodAccumulator(1.0f0), + NumProduceAccumulator(UInt8(1)), + ] + @test combine(acc, split(acc)) == acc + end + end + + @testset "conversions" begin + @test convert(LogPriorAccumulator{Float32}, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert( + LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0) + ) == LogLikelihoodAccumulator{Float32}(1.0f0) + @test convert(NumProduceAccumulator{UInt8}, NumProduceAccumulator(1)) == + NumProduceAccumulator{UInt8}(1) + + @test convert_eltype(Float32, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert_eltype(Float32, LogLikelihoodAccumulator(1.0)) == + LogLikelihoodAccumulator{Float32}(1.0f0) + end + + @testset "accumulate_assume" begin + val = 2.0 + logjac = pi + vn = @varname(x) + dist = Normal() + @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == + LogPriorAccumulator(1.0 + logjac + logpdf(dist, val)) + @test accumulate_assume!!( + LogLikelihoodAccumulator(1.0), val, logjac, vn, dist + ) == LogLikelihoodAccumulator(1.0) + @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) == + NumProduceAccumulator(1) + end + + @testset "accumulate_observe" begin + right = Normal() + left = 2.0 + vn = @varname(x) + @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == + LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == + LogLikelihoodAccumulator(1.0 + logpdf(right, left)) + @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) == + NumProduceAccumulator(2) + end + end + + @testset "accumulator tuples" begin + # Some accumulators 
we'll use for testing + lp_f64 = LogPriorAccumulator(1.0) + lp_f32 = LogPriorAccumulator(1.0f0) + ll_f64 = LogLikelihoodAccumulator(1.0) + ll_f32 = LogLikelihoodAccumulator(1.0f0) + np_i64 = NumProduceAccumulator(1) + + @testset "constructors" begin + @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) + # Names in NamedTuple arguments are ignored + @test AccumulatorTuple((; a=lp_f64)) == AccumulatorTuple(lp_f64) + + # Can't have two accumulators of the same type. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f64) + # Not even if their element types differ. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f32) + end + + @testset "basic operations" begin + at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64) + + @test at_all64[:LogPrior] == lp_f64 + @test at_all64[:LogLikelihood] == ll_f64 + @test at_all64[:NumProduce] == np_i64 + + @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce)) + @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 + @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce) + @test collect(at_all64) == [lp_f64, ll_f64, np_i64] + + # Replace the existing LogPriorAccumulator + @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 + # Check that setacc!! didn't modify the original + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) + # Add a new accumulator type. + @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == + AccumulatorTuple(lp_f64, ll_f64) + + @test getacc(at_all64, Val(:LogPrior)) == lp_f64 + end + + @testset "map_accumulator(s)!!" begin + # map over all accumulators + accs = AccumulatorTuple(lp_f32, ll_f32) + @test map(zero, accs) == AccumulatorTuple( + LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0) + ) + # Test that the original wasn't modified. + @test accs == AccumulatorTuple(lp_f32, ll_f32) + + # A map with a closure that changes the types of the accumulators. 
+ @test map(acc -> convert_eltype(Float64, acc), accs) == + AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0)) + + # only apply to a particular accumulator + @test map_accumulator(zero, accs, Val(:LogLikelihood)) == + AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0)) + @test map_accumulator( + acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) + ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) + end + end +end + +end diff --git a/test/compiler.jl b/test/compiler.jl index a0286d405..81c018111 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -189,12 +189,12 @@ module Issue537 end global model_ = __model__ global context_ = __context__ global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end model = testmodel_missing3([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model @test context_ isa SamplingContext @@ -208,13 +208,13 @@ module Issue537 end global model_ = __model__ global context_ = __context__ global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end false lpold = lp model = testmodel_missing4([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp == lpold + @test getlogjoint(varinfo) == lp == lpold # test DPPL#61 @model function testmodel_missing5(z) @@ -333,14 +333,14 @@ module Issue537 end function makemodel(p) @model function testmodel(x) x[1] ~ Bernoulli(p) - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end return testmodel end model = makemodel(0.5)([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @@ -364,9 +364,9 @@ module Issue537 end # TODO(torfjelde): We need conditioning for `Dict`. 
@test_broken f2_c() == 1 @test_broken f3_c() == 1 - @test_broken getlogp(VarInfo(f1_c)) == - getlogp(VarInfo(f2_c)) == - getlogp(VarInfo(f3_c)) + @test_broken getlogjoint(VarInfo(f1_c)) == + getlogjoint(VarInfo(f2_c)) == + getlogjoint(VarInfo(f3_c)) end @testset "custom tilde" begin @model demo() = begin diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 0ec88c07c..ac6321d69 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -10,7 +10,7 @@ end end - test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) + test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext()) end @testset "dot tilde with varying sizes" begin @@ -18,13 +18,14 @@ @model function test(x, size) y = Array{Float64,length(size)}(undef, size...) y .~ Normal(x) - return y, getlogp(__varinfo__) + return y end for ysize in ((2,), (2, 3), (2, 3, 4)) x = randn() model = test(x, ysize) - y, lp = model() + y = model() + lp = logjoint(model, (; y=y)) @test lp ≈ sum(logpdf.(Normal.(x), y)) ys = [first(model()) for _ in 1:10_000] diff --git a/test/contexts.jl b/test/contexts.jl index 1ba099a37..5f22b75eb 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -9,7 +9,6 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLogdensityContext, contextual_isassumption, FixedContext, ConditionContext, @@ -47,18 +46,11 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "contexts.jl" begin - child_contexts = Dict( + contexts = Dict( :default => DefaultContext(), - :prior => PriorContext(), - :likelihood => LikelihoodContext(), - ) - - parent_contexts = Dict( :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), - :minibatch => MiniBatchContext(DefaultContext(), 0.0), :prefix => PrefixContext(@varname(x)), - :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => 
ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) @@ -70,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :condition4 => ConditionContext((x=[1.0, missing],)), ) - contexts = merge(child_contexts, parent_contexts) - @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) @@ -235,7 +225,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Values from outer context should override inner one ctx1 = ConditionContext(n1, ConditionContext(n2)) @test ctx1.values == (x=1, y=2) - # Check that the two ConditionContexts are collapsed + # Check that the two ConditionContexts are collapsed @test childcontext(ctx1) isa DefaultContext # Then test the nesting the other way round ctx2 = ConditionContext(n2, ConditionContext(n1)) diff --git a/test/independence.jl b/test/independence.jl deleted file mode 100644 index a4a834a61..000000000 --- a/test/independence.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "Turing independence" begin - @model coinflip(y) = begin - p ~ Beta(1, 1) - N = length(y) - for i in 1:N - y[i] ~ Bernoulli(p) - end - end - model = coinflip([1, 1, 0]) - model(SampleFromPrior(), LikelihoodContext()) -end diff --git a/test/linking.jl b/test/linking.jl index d424a9c2d..4f1707263 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -85,7 +85,7 @@ end DynamicPPL.link(vi, model) end # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) + @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. 
@test length(vi_linked[:]) < length(vi[:]) @@ -98,7 +98,7 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) - @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) + @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) end end @@ -130,7 +130,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +138,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint(vi_invlinked) ≈ lp end end end @@ -164,7 +164,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint(vi_linked) ≈ lp) # Invlinked. 
vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +172,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint(vi_invlinked) ≈ lp end end end diff --git a/test/model.jl b/test/model.jl index dd5a35fe6..6e4a24ae6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -41,7 +41,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() m = vi[@varname(m)] # extract log pdf of variable object - lp = getlogp(vi) + lp = getlogjoint(vi) # log prior probability lprior = logprior(model, vi) @@ -494,7 +494,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() varinfo_linked_result = last( DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) ) - @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) end end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 61c842638..cfb222b66 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,6 +1,4 @@ @testset "logdensities_likelihoods.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) - mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -37,11 +35,6 @@ lps = pointwise_logdensities(model, vi) logp = sum(sum, values(lps)) @test logp ≈ (logprior_true + loglikelihood_true) - - # Test that modifications of Setup are picked up - lps = pointwise_logdensities(model, vi, mod_ctx2) - logp = sum(sum, values(lps)) - @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4 end end diff --git a/test/runtests.jl b/test/runtests.jl index 72f33f2d0..4a9acf4e1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,13 +49,13 @@ 
include("test_util.jl") include("Aqua.jl") end include("utils.jl") + include("accumulators.jl") include("compiler.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") include("sampler.jl") - include("independence.jl") include("distribution_wrappers.jl") include("logdensityfunction.jl") include("linking.jl") diff --git a/test/sampler.jl b/test/sampler.jl index 8c4f1ed96..fe9fd331a 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,7 +84,7 @@ let inits = (; p=0.2) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -98,7 +98,7 @@ ) for c in chains @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end @@ -113,7 +113,7 @@ chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -128,7 +128,7 @@ for c in chains @test c[1].metadata.s.vals == [4] @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 380c24e7d..6f2f39a64 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -2,12 +2,12 @@ @testset "constructor & indexing" begin @testset "NamedTuple" begin svi = SimpleVarInfo(; m=1.0) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(; m=[1.0]) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -21,20 +21,21 @@ @test !haskey(svi, 
@varname(m.a.b)) svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogp(svi) isa Float32 + @test getlogjoint(svi) isa Float32 - svi = SimpleVarInfo((m=1.0,), 1.0) - @test getlogp(svi) == 1.0 + svi = SimpleVarInfo((m=1.0,)) + svi = accloglikelihood!!(svi, 1.0) + @test getlogjoint(svi) == 1.0 end @testset "Dict" begin svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -59,12 +60,12 @@ @testset "VarNamedVector" begin svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -98,11 +99,10 @@ vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - lp_orig = getlogp(vi) # `link!!` vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogp(vi_linked) + lp_linked = getlogjoint(vi_linked) values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_constrained... ) @@ -113,7 +113,7 @@ # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogp(vi_invlinked) + lp_invlinked = getlogjoint(vi_invlinked) lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( model, values_constrained... 
) @@ -152,7 +152,7 @@ # DynamicPPL.settrans!!(deepcopy(svi_dict), true), # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) - # RandOM seed is set in each `@testset`, so we need to sample + # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. retval = model() @@ -166,7 +166,7 @@ end # Logjoint should be non-zero wp. 1. - @test getlogp(svi_new) != 0 + @test getlogjoint(svi_new) != 0 ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @@ -201,7 +201,7 @@ svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end - # Reset the logp field. + # Reset the logp accumulators. svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. @@ -250,7 +250,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) + lp = getlogjoint(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -306,7 +306,7 @@ DynamicPPL.tovec(retval_unconstrained.m) # The resulting varinfo should hold the correct logp. 
- lp = getlogp(vi_linked_result) + lp = getlogjoint(vi_linked_result) @test lp ≈ lp_true end end diff --git a/test/submodels.jl b/test/submodels.jl index e79eed2c3..d3a2f17e7 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -35,7 +35,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) end @@ -67,7 +67,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(y)]) end @@ -99,7 +99,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) end @@ -148,7 +148,7 @@ using Test # No conditioning vi = VarInfo(h()) @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) - @test getlogp(vi) == + @test getlogjoint(vi) == logpdf(Normal(), vi[@varname(a.b.x)]) + logpdf(Normal(), vi[@varname(a.b.y)]) @@ -174,7 +174,7 @@ using Test @testset "$name" for (name, model) in models vi = VarInfo(model) @test Set(keys(vi)) == Set([@varname(a.b.y)]) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) end end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 72c439db8..5b4f6951f 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -4,9 +4,12 @@ threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test 
threadsafe_vi.varinfo === vi - @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} - @test length(threadsafe_vi.logps) == Threads.nthreads() - @test all(iszero(x[]) for x in threadsafe_vi.logps) + @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} + @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... + ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end # TODO: Add more tests of the public API @@ -14,23 +17,27 @@ vi = VarInfo(gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - lp = getlogp(vi) - @test getlogp(threadsafe_vi) == lp + lp = getlogjoint(vi) + @test getlogjoint(threadsafe_vi) == lp - acclogp!!(threadsafe_vi, 42) - @test threadsafe_vi.logps[Threads.threadid()][] == 42 - @test getlogp(vi) == lp - @test getlogp(threadsafe_vi) == lp + 42 + threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) + @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 + @test getlogjoint(vi) == lp + @test getlogjoint(threadsafe_vi) == lp + 42 - resetlogp!!(threadsafe_vi) - @test iszero(getlogp(vi)) - @test iszero(getlogp(threadsafe_vi)) - @test all(iszero(x[]) for x in threadsafe_vi.logps) + threadsafe_vi = resetlogp!!(threadsafe_vi) + @test iszero(getlogjoint(threadsafe_vi)) + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... + ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - setlogp!!(threadsafe_vi, 42) - @test getlogp(vi) == 42 - @test getlogp(threadsafe_vi) == 42 - @test all(iszero(x[]) for x in threadsafe_vi.logps) + threadsafe_vi = setlogprior!!(threadsafe_vi, 42) + @test getlogjoint(threadsafe_vi) == 42 + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... 
+ ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @testset "model" begin @@ -48,7 +55,7 @@ vi = VarInfo() wthreads(x)(vi) - lp_w_threads = getlogp(vi) + lp_w_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -65,7 +72,7 @@ vi, SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), ) - @test getlogp(vi) ≈ lp_w_threads + @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe!!:") @@ -85,7 +92,7 @@ vi = VarInfo() wothreads(x)(vi) - lp_wo_threads = getlogp(vi) + lp_wo_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -104,7 +111,7 @@ vi, SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), ) - @test getlogp(vi) ≈ lp_w_threads + @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe!!:") diff --git a/test/utils.jl b/test/utils.jl index d683f132d..e4bac14e0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,15 +1,34 @@ @testset "utils.jl" begin @testset "addlogprob!" begin @model function testmodel() - global lp_before = getlogp(__varinfo__) + global lp_before = getlogjoint(__varinfo__) @addlogprob!(42) - return global lp_after = getlogp(__varinfo__) + return global lp_after = getlogjoint(__varinfo__) end - model = testmodel() - varinfo = VarInfo(model) + varinfo = VarInfo(testmodel()) @test iszero(lp_before) - @test getlogp(varinfo) == lp_after == 42 + @test getlogjoint(varinfo) == lp_after == 42 + @test getloglikelihood(varinfo) == 42 + + @model function testmodel_nt() + global lp_before = getlogjoint(__varinfo__) + @addlogprob! 
(; logprior=(pi + 1), loglikelihood=42) + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_nt()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + 1 + pi + @test getloglikelihood(varinfo) == 42 + @test getlogprior(varinfo) == pi + 1 + + @model function testmodel_nt2() + global lp_before = getlogjoint(__varinfo__) + llh_nt = (; loglikelihood=42) + @addlogprob! llh_nt + return global lp_after = getlogjoint(__varinfo__) + end end @testset "getargs_dottilde" begin diff --git a/test/varinfo.jl b/test/varinfo.jl index 777917aa6..1c597f951 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -80,7 +80,7 @@ end function test_base!!(vi_original) vi = empty!!(vi_original) - @test getlogp(vi) == 0 + @test getlogjoint(vi) == 0 @test isempty(vi[:]) vn = @varname x @@ -123,13 +123,25 @@ end @testset "get/set/acc/resetlogp" begin function test_varinfo_logp!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 - vi = DynamicPPL.setlogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 1.0 - vi = DynamicPPL.acclogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 + vi = DynamicPPL.setlogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 1.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 1.0 + vi = DynamicPPL.acclogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 2.0 + vi = DynamicPPL.setloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 1.0 + @test DynamicPPL.getlogjoint(vi) === 3.0 + vi = DynamicPPL.accloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 4.0 vi = DynamicPPL.resetlogp!!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 end vi = 
VarInfo() @@ -140,6 +152,98 @@ end test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + @testset "accumulators" begin + @model function demo() + a ~ Normal() + b ~ Normal() + c ~ Normal() + d ~ Normal() + return nothing + end + + values = (; a=1.0, b=2.0, c=3.0, d=4.0) + lp_a = logpdf(Normal(), values.a) + lp_b = logpdf(Normal(), values.b) + lp_c = logpdf(Normal(), values.c) + lp_d = logpdf(Normal(), values.d) + m = demo() | (; c=values.c, d=values.d) + + vi = DynamicPPL.reset_num_produce!!( + DynamicPPL.unflatten(VarInfo(m), collect(values)) + ) + + vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) + @test getlogprior(vi) == lp_a + lp_b + @test getloglikelihood(vi) == lp_c + lp_d + @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d) + @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d + @test get_num_produce(vi) == 2 + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = accloglikelihood!!(vi, 1.0) + getloglikelihood(vi) == lp_c + lp_d + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + @test begin + vi = setloglikelihood!!(vi, -1.0) + getloglikelihood(vi) == -1.0 + end + @test begin + vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0) + end + @test begin + vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) + getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0) + end + @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) + + vi = last( + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),)) + ), + ) + @test getlogprior(vi) == lp_a + lp_b + @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogLikelihood" getlogp(vi) + @test_throws "has no field LogLikelihood" getlogjoint(vi) + @test_throws "has no field NumProduce" get_num_produce(vi) + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) 
== lp_a + lp_b + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + + vi = last( + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),)) + ), + ) + @test_throws "has no field LogPrior" getlogprior(vi) + @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogPrior" getlogp(vi) + @test_throws "has no field LogPrior" getlogjoint(vi) + @test get_num_produce(vi) == 2 + + # Test evaluating without any accumulators. + vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) + @test_throws "has no field LogPrior" getlogprior(vi) + @test_throws "has no field LogLikelihood" getloglikelihood(vi) + @test_throws "has no field LogPrior" getlogp(vi) + @test_throws "has no field LogPrior" getlogjoint(vi) + @test_throws "has no field NumProduce" get_num_produce(vi) + @test_throws "has no field NumProduce" reset_num_produce!!(vi) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -455,12 +559,24 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) + + ## `untyped_varinfo` + vi = DynamicPPL.untyped_varinfo(model) + vi = DynamicPPL.settrans!!(vi, true, vn) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. 
vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) @@ -469,7 +585,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` @@ -478,7 +594,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) @@ -486,7 +602,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) @@ -494,7 +610,7 @@ end vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @testset "values_as" begin @@ -596,8 +712,8 @@ end lp = logjoint(model, varinfo) @test 
lp ≈ lp_true - @test getlogp(varinfo) ≈ lp_true - lp_linked = getlogp(varinfo_linked) + @test getlogjoint(varinfo) ≈ lp_true + lp_linked = getlogjoint(varinfo_linked) @test lp_linked ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for @@ -609,13 +725,36 @@ end varinfo_linked_unflattened, model ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) - @test getlogp(varinfo_invlinked) ≈ lp_true + @test getlogjoint(varinfo_invlinked) ≈ lp_true end end end end end + @testset "unflatten type stability" begin + @model function demo(y) + x ~ Normal() + y ~ Normal(x, 1) + return nothing + end + + model = demo(0.0) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, (; x=1.0), (@varname(x),); include_threadsafe=true + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + # Skip the severely inconcrete `SimpleVarInfo` types, since checking for type + # stability for them doesn't make much sense anyway. + if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} || + varinfo isa + DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}} + continue + end + @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) + end + end + @testset "subset" begin @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) @@ -941,19 +1080,19 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = 
DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -961,12 +1100,12 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @@ -975,21 +1114,21 @@ end vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -997,12 +1136,12 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = 
DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] @@ -1017,8 +1156,8 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) # `Int`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) end end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index bd3f5553f..f21d458a8 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -607,7 +607,7 @@ end DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) ) # Log density should be the same. - @test getlogp(varinfo_eval) ≈ logp_true + @test getlogjoint(varinfo_eval) ≈ logp_true # Values should be the same. DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) @@ -616,7 +616,7 @@ end DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) ) # Log density should be different. - @test getlogp(varinfo_sample) != getlogp(varinfo) + @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. 
DynamicPPL.TestUtils.test_values( varinfo_sample, value_true, vns; compare=!isequal From 326d7ed002ac76b8573aabf72510b98ebd43b418 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 May 2025 09:19:47 +0100 Subject: [PATCH 03/27] Replace PriorExtractorContext with PriorDistributionAccumulator (#907) --- src/extract_priors.jl | 58 ++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 0f312fa2c..9047c9f0a 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -1,44 +1,47 @@ -struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <: - AbstractContext +struct PriorDistributionAccumulator{D<:OrderedDict{VarName,Any}} <: AbstractAccumulator priors::D - context::Ctx end -PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context) +PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}()) -NodeTrait(::PriorExtractorContext) = IsParent() -childcontext(context::PriorExtractorContext) = context.context -function setchildcontext(parent::PriorExtractorContext, child) - return PriorExtractorContext(parent.priors, child) +accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator + +split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) +function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator) + return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors)) end -function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution) - return context.priors[vn] = dist +function setprior!(acc::PriorDistributionAccumulator, vn::VarName, dist::Distribution) + acc.priors[vn] = dist + return acc end function setprior!( - context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, 
dist::Distribution ) for vn in vns - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end function setprior!( - context::PriorExtractorContext, + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution}, ) for (vn, dist) in zip(vns, dists) - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end -function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) +function accumulate_assume!!(acc::PriorDistributionAccumulator, val, logjac, vn, right) + return setprior!(acc, vn, right) end +accumulate_observe!!(acc::PriorDistributionAccumulator, right, left, vn) = acc + """ extract_priors([rng::Random.AbstractRNG, ]model::Model) @@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - context = PriorExtractorContext(SamplingContext(rng)) - evaluate!!(model, VarInfo(), context) - return context.priors + varinfo = VarInfo() + # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a + # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you + # can't push new variables without knowing the num_produce. Remove this when possible. + varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) + varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng))) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end """ @@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo` and recording the distributions that are present at each tilde statement. 
""" function extract_priors(model::Model, varinfo::AbstractVarInfo) - context = PriorExtractorContext(DefaultContext()) - evaluate!!(model, deepcopy(varinfo), context) - return context.priors + # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a + # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you + # can't push new variables without knowing the num_produce. Remove this when possible. + varinfo = setaccs!!( + deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator()) + ) + varinfo = last(evaluate!!(model, varinfo, DefaultContext())) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end From d4ef1f2dadff5b4301f7c468aaa4feb11cc4eb3e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 May 2025 15:32:50 +0100 Subject: [PATCH 04/27] Implement values_as_in_model using an accumulator (#908) * Implement values_as_in_model using an accumulator * Make make_varname_expression a function * Refuse to combine ValuesAsInModelAccumulators with different include_colon_eqs * Fix nested context test --- src/compiler.jl | 40 ++++++++++------- src/values_as_in_model.jl | 95 +++++++++++++-------------------------- test/compiler.jl | 31 ++++++++++++- test/contexts.jl | 2 +- 4 files changed, 85 insertions(+), 83 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 9eb4835d3..b783c2a13 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -29,6 +29,18 @@ function need_concretize(expr) end end +""" + make_varname_expression(expr) + +Return a `VarName` based on `expr`, concretizing it if necessary. +""" +function make_varname_expression(expr) + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. 
+ return AbstractPPL.drop_escape(varname(expr, need_concretize(expr))) +end + """ isassumption(expr[, vn]) @@ -48,10 +60,7 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption( - expr::Union{Expr,Symbol}, - vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), -) +function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote if $(DynamicPPL.contextual_isassumption)( __context__, $(DynamicPPL.prefix)(__context__, $vn) @@ -402,14 +411,18 @@ function generate_mainbody!(mod, found, expr::Expr, warn) end function generate_assign(left, right) - right_expr = :($(TrackedValue)($right)) - tilde_expr = generate_tilde(left, right_expr) + # A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for + # ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator. + @gensym acc right_val vn return quote - if $(is_extracting_values)(__context__) - $tilde_expr - else - $left = $right + $right_val = $right + if $(DynamicPPL.is_extracting_values)(__varinfo__) + $vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left))) + __varinfo__ = $(map_accumulator!!)( + $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) + ) end + $left = $right_val end end @@ -437,14 +450,9 @@ function generate_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption value dist - # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact - # that in DynamicPPL we the entire function body. Instead we should be - # more selective with our escape. Until that's the case, we remove them all. 
return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist - ) + $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)( diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 3ec474940..4d6225c10 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -1,16 +1,7 @@ -struct TrackedValue{T} - value::T -end - -is_tracked_value(::TrackedValue) = true -is_tracked_value(::Any) = false - -check_tilde_rhs(x::TrackedValue) = x - """ - ValuesAsInModelContext + ValuesAsInModelAccumulator <: AbstractAccumulator -A context that is used by [`values_as_in_model`](@ref) to obtain values +An accumulator that is used by [`values_as_in_model`](@ref) to obtain values of the model parameters as they are in the model. This is particularly useful when working in unconstrained space, but one @@ -19,72 +10,47 @@ wants to extract the realization of a model in a constrained space. 
# Fields $(TYPEDFIELDS) """ -struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext +struct ValuesAsInModelAccumulator <: AbstractAccumulator "values that are extracted from the model" values::OrderedDict "whether to extract variables on the LHS of :=" include_colon_eq::Bool - "child context" - context::C end -function ValuesAsInModelContext(include_colon_eq, context::AbstractContext) - return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context) +function ValuesAsInModelAccumulator(include_colon_eq) + return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq) end -NodeTrait(::ValuesAsInModelContext) = IsParent() -childcontext(context::ValuesAsInModelContext) = context.context -function setchildcontext(context::ValuesAsInModelContext, child) - return ValuesAsInModelContext(context.values, context.include_colon_eq, child) -end +accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel -is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq -function is_extracting_values(context::AbstractContext) - return is_extracting_values(NodeTrait(context), context) +function split(acc::ValuesAsInModelAccumulator) + return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq) end -is_extracting_values(::IsParent, ::AbstractContext) = false -is_extracting_values(::IsLeaf, ::AbstractContext) = false - -function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), prefix(context, vn)) +function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) + if acc1.include_colon_eq != acc2.include_colon_eq + msg = "Cannot combine accumulators with different include_colon_eq values." 
+ throw(ArgumentError(msg)) + end + return ValuesAsInModelAccumulator( + merge(acc1.values, acc2.values), acc1.include_colon_eq + ) end -function broadcast_push!(context::ValuesAsInModelContext, vns, values) - return push!.((context,), vns, values) +function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val) + setindex!(acc.values, deepcopy(val), vn) + return acc end -# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. -function broadcast_push!( - context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix -) - for (vn, col) in zip(vns, eachcol(values)) - push!(context, vn, col) - end +function is_extracting_values(vi::AbstractVarInfo) + return hasacc(vi, Val(:ValuesAsInModel)) && + getacc(vi, Val(:ValuesAsInModel)).include_colon_eq end -# `tilde_asssume` -function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) - if is_tracked_value(right) - value = right.value - else - value, vi = tilde_assume(childcontext(context), right, vn, vi) - end - push!(context, vn, value) - return value, vi -end -function tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi -) - if is_tracked_value(right) - value = right.value - else - value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - end - # Save the value. - push!(context, vn, value) - # Pass on. - return value, vi +function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right) + return push!(acc, vn, val) end +accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc + """ values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) @@ -103,7 +69,7 @@ space at the cost of additional model evaluations. - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. 
- `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: base context to use for the extraction. Defaults +- `context::AbstractContext`: evaluation context to use in the extraction. Defaults to `DynamicPPL.DefaultContext()`. # Examples @@ -164,7 +130,8 @@ function values_as_in_model( varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext(), ) - context = ValuesAsInModelContext(include_colon_eq, context) - evaluate!!(model, varinfo, context) - return context.values + accs = getaccs(varinfo) + varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),)) + varinfo = last(evaluate!!(model, varinfo, context)) + return getacc(varinfo, Val(:ValuesAsInModel)).values end diff --git a/test/compiler.jl b/test/compiler.jl index 81c018111..2e76de27f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -732,10 +732,10 @@ module Issue537 end y := 100 + x return (; x, y) end - @model function demo_tracked_submodel() + @model function demo_tracked_submodel_no_prefix() return vals ~ to_submodel(demo_tracked(), false) end - for model in [demo_tracked(), demo_tracked_submodel()] + for model in [demo_tracked(), demo_tracked_submodel_no_prefix()] # Make sure it's runnable and `y` is present in the return-value. @test model() isa NamedTuple{(:x, :y)} @@ -756,6 +756,33 @@ module Issue537 end @test haskey(values, @varname(x)) @test !haskey(values, @varname(y)) end + + @model function demo_tracked_return_x() + x ~ Normal() + y := 100 + x + return x + end + @model function demo_tracked_submodel_prefix() + return a ~ to_submodel(demo_tracked_return_x()) + end + @model function demo_tracked_subsubmodel_prefix() + return b ~ to_submodel(demo_tracked_submodel_prefix()) + end + # As above, but the variables should now have their names prefixed with `b.a`. 
+ model = demo_tracked_subsubmodel_prefix() + varinfo = VarInfo(model) + @test haskey(varinfo, @varname(b.a.x)) + @test length(keys(varinfo)) == 1 + + values = values_as_in_model(model, true, deepcopy(varinfo)) + @test haskey(values, @varname(b.a.x)) + @test haskey(values, @varname(b.a.y)) + + # And if include_colon_eq is set to `false`, then `values` should + # only contain `x`. + values = values_as_in_model(model, false, deepcopy(varinfo)) + @test haskey(values, @varname(b.a.x)) + @test length(keys(varinfo)) == 1 end @testset "signature parsing + TypeWrap" begin diff --git a/test/contexts.jl b/test/contexts.jl index 5f22b75eb..1dd6a2280 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -154,7 +154,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) + ctx4 = DynamicPPL.SamplingContext(ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end From d04310328ecb2c46d099bf0d4584b669ed4ba942 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 12:55:21 +0100 Subject: [PATCH 05/27] Bump DynamicPPL versions --- benchmarks/Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 2b3bfbbdd..bd4e16663 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -19,7 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.36" +DynamicPPL = "0.37" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" diff --git a/docs/Project.toml b/docs/Project.toml index c00c29c96..fb86a087e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -20,7 +20,7 @@ DataStructures = "0.18" Distributions 
= "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.36" +DynamicPPL = "0.37" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" From d9545c6f0f8d49939f32f0296bf25b41f242a7fc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 15:08:24 +0100 Subject: [PATCH 06/27] Fix merge (1) --- src/simple_varinfo.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b0df8ccde..42fcedfb8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -632,15 +632,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) return invlink_transform(dist) end -# Threadsafe stuff. -# For `SimpleVarInfo` we don't really need `Ref` so let's not use it. -function ThreadSafeVarInfo(vi::SimpleVarInfo) - return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads() * 2)) -end -function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) - return ThreadSafeVarInfo( - vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)] - ) -end - has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector From a243bbb44261a055649471c32b70577c7b3b74da Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 16:10:26 +0100 Subject: [PATCH 07/27] Add benchmark Pkg source --- benchmarks/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index bd4e16663..3d14d03ff 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -15,6 +15,9 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +[sources] +DynamicPPL = {path = "../"} + [compat] ADTypes = "1.14.0" BenchmarkTools = "1.6.0" From e2272a578ac18d4165cb93274ac487ddac522b48 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 16:28:40 +0100 Subject: [PATCH 08/27] [no ci] Don't need to dev 
again --- benchmarks/benchmarks.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 9661dd505..b733d810c 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -1,6 +1,4 @@ using Pkg -# To ensure we benchmark the local version of DynamicPPL, dev the folder above. -Pkg.develop(; path=joinpath(@__DIR__, "..")) using DynamicPPLBenchmarks: Models, make_suite, model_dimension using BenchmarkTools: @benchmark, median, run From 3cb47cdf16e384fda5746c89e82eb1f6c3250c73 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 17:28:20 +0100 Subject: [PATCH 09/27] Disable use_closure for ReverseDiff --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 35716aa9f..2b0620757 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -308,7 +308,7 @@ the constant approach will be used. use_closure(::ADTypes.AbstractADType) = false use_closure(::ADTypes.AutoForwardDiff) = false use_closure(::ADTypes.AutoMooncake) = false -use_closure(::ADTypes.AutoReverseDiff) = true +use_closure(::ADTypes.AutoReverseDiff) = false """ getmodel(f) From 2d11ad78e30e3d2f5aa774a04f8d83e9e189f2c4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 17:55:16 +0100 Subject: [PATCH 10/27] Revert "Disable use_closure for ReverseDiff" This reverts commit 3cb47cdf16e384fda5746c89e82eb1f6c3250c73. --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 2b0620757..35716aa9f 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -308,7 +308,7 @@ the constant approach will be used. 
use_closure(::ADTypes.AbstractADType) = false use_closure(::ADTypes.AutoForwardDiff) = false use_closure(::ADTypes.AutoMooncake) = false -use_closure(::ADTypes.AutoReverseDiff) = false +use_closure(::ADTypes.AutoReverseDiff) = true """ getmodel(f) From 0445092b9f40d1f36239a8f0d94bf7a94e7a73bc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 17:56:06 +0100 Subject: [PATCH 11/27] Fix LogDensityAt struct --- src/logdensityfunction.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 35716aa9f..323bf22a4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -212,7 +212,18 @@ struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} end function (ld::LogDensityAt)(x::AbstractVector) varinfo_new = unflatten(ld.varinfo, x) - return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context))) + varinfo_eval = last(evaluate!!(ld.model, varinfo_new, ld.context)) + has_prior = hasacc(varinfo_eval, Val(:LogPrior)) + has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) + if has_prior && has_likelihood + return getlogjoint(varinfo_eval) + elseif has_prior + return getlogprior(varinfo_eval) + elseif has_likelihood + return getloglikelihood(varinfo_eval) + else + error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") + end end ### LogDensityProblems interface From ff7f8a2270edfcaeab4fcd73577c1c8e2db661a7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 18:00:31 +0100 Subject: [PATCH 12/27] Try not duplicating --- src/logdensityfunction.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 323bf22a4..b0ff7fdc7 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -211,19 +211,7 @@ struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} context::C end function 
(ld::LogDensityAt)(x::AbstractVector) - varinfo_new = unflatten(ld.varinfo, x) - varinfo_eval = last(evaluate!!(ld.model, varinfo_new, ld.context)) - has_prior = hasacc(varinfo_eval, Val(:LogPrior)) - has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) - if has_prior && has_likelihood - return getlogjoint(varinfo_eval) - elseif has_prior - return getlogprior(varinfo_eval) - elseif has_likelihood - return getloglikelihood(varinfo_eval) - else - error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") - end + return logdensity_at(x, ld.model, ld.varinfo, ld.context) end ### LogDensityProblems interface From 80db9e2bf197b86423a7b62691487de6b45b3d62 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 2 Jun 2025 18:17:23 +0100 Subject: [PATCH 13/27] Update comment pointing to closure benchmarks --- src/logdensityfunction.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index b0ff7fdc7..c5586f80f 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -297,17 +297,14 @@ There are two ways of dealing with this: The relative performance of the two approaches, however, depends on the AD backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480 +https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 This function is used to determine whether a given AD backend should use a closure or a constant. If `use_closure(adtype)` returns `true`, then the closure approach will be used. By default, this function returns `false`, i.e. the constant approach will be used. 
""" -use_closure(::ADTypes.AbstractADType) = false -use_closure(::ADTypes.AutoForwardDiff) = false -use_closure(::ADTypes.AutoMooncake) = false -use_closure(::ADTypes.AutoReverseDiff) = true +use_closure(::ADTypes.AbstractADType) = true """ getmodel(f) From 3af63d59826591e59a23a17d6ce4d29e5367a200 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 17:45:44 +0100 Subject: [PATCH 14/27] Remove `context` from model evaluation (use `model.context` instead) (#952) * Change `evaluate!!` API, add `sample!!` * Fix literally everything else that I broke * Fix some docstrings * fix ForwardDiffExt (look, multiple dispatch bad...) * Changelog * fix a test * Fix docstrings * use `sample!!` * Fix a couple more cases * Globally rename `sample!!` -> `evaluate_and_sample!!`, add changelog warning --- HISTORY.md | 36 +++++ benchmarks/src/DynamicPPLBenchmarks.jl | 3 +- docs/src/api.md | 18 ++- ext/DynamicPPLForwardDiffExt.jl | 1 - ext/DynamicPPLJETExt.jl | 22 +-- ext/DynamicPPLMCMCChainsExt.jl | 2 +- src/DynamicPPL.jl | 1 + src/compiler.jl | 30 ++-- src/debug_utils.jl | 53 +++---- src/experimental.jl | 16 +- src/extract_priors.jl | 4 +- src/logdensityfunction.jl | 101 +++++------- src/model.jl | 211 +++++++++++++++---------- src/pointwise_logdensities.jl | 59 +++---- src/sampler.jl | 17 +- src/simple_varinfo.jl | 49 +++--- src/submodel_macro.jl | 30 ++-- src/test_utils/ad.jl | 21 +-- src/test_utils/model_interface.jl | 2 +- src/test_utils/varinfo.jl | 2 +- src/threadsafe.jl | 11 +- src/transforming.jl | 9 +- src/values_as_in_model.jl | 14 +- src/varinfo.jl | 156 ++++-------------- test/ad.jl | 5 +- test/compiler.jl | 23 ++- test/context_implementations.jl | 4 +- test/contexts.jl | 3 +- test/debug_utils.jl | 18 +-- test/ext/DynamicPPLForwardDiffExt.jl | 5 +- test/ext/DynamicPPLJETExt.jl | 5 +- test/linking.jl | 2 +- test/model.jl | 24 ++- test/simple_varinfo.jl | 16 +- test/threadsafe.jl | 40 ++--- test/varinfo.jl | 33 ++-- test/varnamedvector.jl | 6 +- 37 files 
changed, 477 insertions(+), 575 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 7ab9ee1dc..9edac441f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -18,6 +18,42 @@ This release overhauls how VarInfo objects track variables such as the log joint - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well. +### Evaluation contexts + +Historically, evaluating a DynamicPPL model has required three arguments: a model, some kind of VarInfo, and a context. +It's less known, though, that since DynamicPPL 0.14.0 the _model_ itself actually contains a context as well. +This version therefore excises the context argument, and instead uses `model.context` as the evaluation context. + +The upshot of this is that many functions that previously took a context argument now no longer do. +There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value). + +`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. +If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost. +If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely. + +To aid with this process, `contextualize` is now exported from DynamicPPL. 
+ +The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`. +Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object. +Thus, this release also introduces the **unexported** function `evaluate_and_sample!!`. +Essentially, `evaluate_and_sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`. +**Do note that this is an internal method**, and its name or semantics are liable to change in the future without warning. + +There are many methods that no longer take a context argument, and listing them all would be too much. +However, here are the more user-facing ones: + + - `LogDensityFunction` no longer has a context field (or type parameter) + - `DynamicPPL.TestUtils.AD.run_ad` no longer uses a context (and the returned `ADResult` object no longer has a context field) + - `VarInfo(rng, model, sampler)` and other VarInfo constructors / functions that made VarInfos (e.g. `typed_varinfo`) from a model + - `(::Model)(args...)`: specifically, this now only takes `rng` and `varinfo` arguments (with both being optional) + - If you are using the `__context__` special variable inside a model, you will now have to use `__model__.context` instead + +And a couple of more internal changes: + + - `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments + - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`) + - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument + ## 0.36.12 Removed several unexported functions. 
diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 6f486e2f5..26ec35b65 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: end adbackend = to_backend(adbackend) - context = DynamicPPL.DefaultContext() if islinked vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend) + f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/src/api.md b/docs/src/api.md index 8e5c64886..32b3d80a6 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,6 +36,12 @@ getargnames getmissings ``` +The context of a model can be set using [`contextualize`](@ref): + +```@docs +contextualize +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). @@ -438,13 +444,21 @@ DynamicPPL.varname_and_value_leaves ### Evaluation Contexts -Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref). +Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref). ```@docs AbstractPPL.evaluate!! ``` -The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function. +This method mutates the `varinfo` used for execution. +By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: + +```@docs +DynamicPPL.evaluate_and_sample!! +``` + +The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. 
```@docs diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 6bd7a5d94..7ea51918f 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -11,7 +11,6 @@ function DynamicPPL.tweak_adtype( ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo, - ::DynamicPPL.AbstractContext, ) where {chunk_size} params = vi[:] diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index aa95093f2..760d17bb0 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo; - only_ddpl::Bool=true, + model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) # Let's make sure that both evaluation and sampling doesn't result in type errors. - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, context - ) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. result = if only_ddpl @@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo( end function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true + model::DynamicPPL.Model; only_ddpl::Bool=true ) + # Use SamplingContext to test type stability. + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(model, context) + varinfo = DynamicPPL.typed_varinfo(sampling_model) # Let's make sure that both evaluation and sampling doesn't result in type errors. 
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - model, context, varinfo; only_ddpl + sampling_model, varinfo; only_ddpl ) if !issuccess @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model, context) + DynamicPPL.untyped_varinfo(sampling_model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 70f0f0182..a29696720 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -115,7 +115,7 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - model(rng, varinfo, DynamicPPL.SampleFromPrior()) + varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9217deb4f..4bd4f2529 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -102,6 +102,7 @@ export AbstractVarInfo, # LogDensityFunction LogDensityFunction, # Contexts + contextualize, SamplingContext, DefaultContext, PrefixContext, diff --git a/src/compiler.jl b/src/compiler.jl index b783c2a13..22dff33a2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,4 +1,4 @@ -const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +const INTERNALNAMES = (:__model__, :__varinfo__) """ need_concretize(expr) @@ -63,9 +63,9 @@ used in its place. 
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote if $(DynamicPPL.contextual_isassumption)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) - # Considered an assumption by `__context__` which means either: + # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of # the model arguments, hence we need to check this. @@ -116,7 +116,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) return :($(DynamicPPL.contextual_isfixed)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) )) end @@ -417,7 +417,7 @@ function generate_assign(left, right) return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) - $vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left))) + $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) __varinfo__ = $(map_accumulator!!)( $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) @@ -431,7 +431,11 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__ + __model__.context, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + nothing, + __varinfo__, ) $value end @@ -456,7 +460,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) @@ -464,12 +468,12 @@ function 
generate_tilde(left, right) # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) $left = $(DynamicPPL.getconditioned_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, + __model__.context, $(DynamicPPL.check_tilde_rhs)($dist), $(maybe_view(left)), $vn, @@ -494,7 +498,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( - __context__, + __model__.context, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) @@ -652,11 +656,7 @@ function build_output(modeldef, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( - [ - :(__model__::$(DynamicPPL.Model)), - :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__context__::$(DynamicPPL.AbstractContext)), - ], + [:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))], args, ) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 54852a736..4343ce8ac 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -131,9 +131,7 @@ A context used for checking validity of a model. 
# Fields $(FIELDS) """ -struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext - "model that is being run" - model::M +struct DebugContext{C<:AbstractContext} <: AbstractContext "context used for running the model" context::C "mapping from varnames to the number of times they have been seen" @@ -149,7 +147,6 @@ struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext end function DebugContext( - model::Model, context::AbstractContext=DefaultContext(); varnames_seen=OrderedDict{VarName,Int}(), statements=Vector{Stmt}(), @@ -158,7 +155,6 @@ function DebugContext( record_varinfo=false, ) return DebugContext( - model, context, varnames_seen, statements, @@ -344,7 +340,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int}) end # A check we run on the model before evaluating it. -function check_model_pre_evaluation(context::DebugContext, model::Model) +function check_model_pre_evaluation(model::Model) issuccess = true # If something is in the model arguments, then it should NOT be in `condition`, # nor should there be any symbol present in `condition` that has the same symbol. @@ -361,8 +357,8 @@ function check_model_pre_evaluation(context::DebugContext, model::Model) return issuccess end -function check_model_post_evaluation(context::DebugContext, model::Model) - return check_varnames_seen(context.varnames_seen) +function check_model_post_evaluation(model::Model) + return check_varnames_seen(model.context.varnames_seen) end """ @@ -438,25 +434,23 @@ function check_model_and_trace( rng::Random.AbstractRNG, model::Model; varinfo=VarInfo(), - context=SamplingContext(rng), error_on_failure=false, kwargs..., ) # Execute the model with the debug context. debug_context = DebugContext( - model, context; error_on_failure=error_on_failure, kwargs... + SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... 
) + debug_model = DynamicPPL.contextualize(model, debug_context) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_context, model) + issuccess = check_model_pre_evaluation(debug_model) # Force single-threaded execution. - retval, varinfo_result = DynamicPPL.evaluate_threadunsafe!!( - model, varinfo, debug_context - ) + DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo) # Perform checks after evaluating the model. - issuccess &= check_model_post_evaluation(debug_context, model) + issuccess &= check_model_post_evaluation(debug_model) if !issuccess && error_on_failure error("model check failed") @@ -535,14 +529,13 @@ function has_static_constraints( end """ - gen_evaluator_call_with_types(model[, varinfo, context]) + gen_evaluator_call_with_types(model[, varinfo]) Generate the evaluator call and the types of the arguments. # Arguments - `model::Model`: The model whose evaluator is of interest. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Returns A 2-tuple with the following elements: @@ -551,11 +544,9 @@ A 2-tuple with the following elements: - `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator. 
""" function gen_evaluator_call_with_types( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(), + model::Model, varinfo::AbstractVarInfo=VarInfo(model) ) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) return if isempty(kwargs) (model.f, Base.typesof(args...)) else @@ -564,7 +555,7 @@ function gen_evaluator_call_with_types( end """ - model_warntype(model[, varinfo, context]; optimize=true) + model_warntype(model[, varinfo]; optimize=true) Check the type stability of the model's evaluator, warning about any potential issues. @@ -573,23 +564,19 @@ This simply calls `@code_warntype` on the model's evaluator, filling in internal # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `false`. """ function model_warntype( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=false, + model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=false ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize) end """ - model_typed(model[, varinfo, context]; optimize=true) + model_typed(model[, varinfo]; optimize=true) Return the type inference for the model's evaluator. @@ -598,18 +585,14 @@ This simply calls `@code_typed` on the model's evaluator, filling in internal ar # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. 
Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `true`. """ function model_typed( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=true, + model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=true ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize)) end diff --git a/src/experimental.jl b/src/experimental.jl index 84038803c..974912957 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -4,16 +4,15 @@ using DynamicPPL: DynamicPPL # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ - is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) + is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) -Check if the `model` supports evaluation using the provided `context` and `varinfo`. +Check if the `model` supports evaluation using the provided `varinfo`. !!! warning Loading JET.jl is required before calling this function. # Arguments - `model`: The model to verify the support for. -- `context`: The context to use for the model evaluation. - `varinfo`: The varinfo to verify the support for. # Keyword Arguments @@ -29,7 +28,7 @@ function is_suitable_varinfo end function _determine_varinfo_jet end """ - determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) + determine_suitable_varinfo(model; only_ddpl::Bool=true) Return a suitable varinfo for the given `model`. 
@@ -41,7 +40,6 @@ See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). # Arguments - `model`: The model for which to determine the varinfo. -- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. # Keyword Arguments - `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. @@ -85,14 +83,10 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) true ``` """ -function determine_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); - only_ddpl::Bool=true, -) +function determine_suitable_varinfo(model::DynamicPPL.Model; only_ddpl::Bool=true) # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model, context; only_ddpl) + _determine_varinfo_jet(model; only_ddpl) else # Warn the user. @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 9047c9f0a..bd6bdb2f2 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -116,7 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model) # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you # can't push new variables without knowing the num_produce. Remove this when possible. 
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) - varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng))) + varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end @@ -135,6 +135,6 @@ function extract_priors(model::Model, varinfo::AbstractVarInfo) varinfo = setaccs!!( deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator()) ) - varinfo = last(evaluate!!(model, varinfo, DefaultContext())) + varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index c5586f80f..e7565d137 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,8 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -29,10 +28,9 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with its -the type of varinfo to be used, as well as the evaluation context. These must -be known in order to calculate the log density (using -[`DynamicPPL.evaluate!!`](@ref)). +At its most basic level, a LogDensityFunction wraps the model together with the +type of varinfo to be used. These must be known in order to calculate the log +density (using [`DynamicPPL.evaluate!!`](@ref)). 
If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -95,14 +93,12 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} } <: AbstractModel "model used for evaluation" model::M "varinfo used for evaluation" varinfo::V - "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" - context::C "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" adtype::AD "(internal use only) gradient preparation object for the model" @@ -110,35 +106,29 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=leafcontext(model.context); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing prep = nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo, context) + adtype = tweak_adtype(adtype, model, varinfo) # Check whether it is supported is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." 
# Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x) + prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x) else prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(varinfo), - DI.Constant(context), + logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo) ) end end - return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + return new{typeof(model),typeof(varinfo),typeof(adtype)}( + model, varinfo, adtype, prep ) end end @@ -149,9 +139,9 @@ end adtype::Union{Nothing,ADTypes.AbstractADType} ) -Create a new LogDensityFunction using the model, varinfo, and context from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass -`nothing` as the second argument. +Create a new LogDensityFunction using the model and varinfo from the given +`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, +pass `nothing` as the second argument. """ function LogDensityFunction( f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} @@ -159,7 +149,7 @@ function LogDensityFunction( return if adtype === f.adtype f # Avoid recomputing prep if not needed else - LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) + LogDensityFunction(f.model, f.varinfo; adtype=adtype) end end @@ -168,20 +158,18 @@ end x::AbstractVector, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Evaluate the log density of the given `model` at the given parameter values `x`, -using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted -into it, and its own parameters are discarded. 
It does, however, determine whether the log -prior, likelihood, or joint is returned, based on which accumulators are set in it. +using the given `varinfo`. Note that the `varinfo` argument is provided only +for its structure, in the sense that the parameters from the vector `x` are +inserted into it, and its own parameters are discarded. It does, however, +determine whether the log prior, likelihood, or joint is returned, based on +which accumulators are set in it. """ -function logdensity_at( - x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) +function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo) varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new, context)) + varinfo_eval = last(evaluate!!(model, varinfo_new)) has_prior = hasacc(varinfo_eval, Val(:LogPrior)) has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) if has_prior && has_likelihood @@ -196,60 +184,48 @@ function logdensity_at( end """ - LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}( + LogDensityAt{M<:Model,V<:AbstractVarInfo}( model::M varinfo::V - context::C ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -varinfo, context)`. +varinfo)`. 
""" -struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} +struct LogDensityAt{M<:Model,V<:AbstractVarInfo} model::M varinfo::V - context::C -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.varinfo, ld.context) end +(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo) ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,Nothing}} -) where {M,V,C} + ::Type{<:LogDensityFunction{M,V,Nothing}} +) where {M,V} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,AD}} -) where {M,V,C,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,V,AD}} +) where {M,V,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo, f.context) + return logdensity_at(x, f.model, f.varinfo) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,C,AD}, x::AbstractVector -) where {M,V,C,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,V,AD}, x::AbstractVector +) where {M,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type # Make branching statically inferrable, i.e. 
type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x - ) + DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x) else DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.varinfo), - DI.Constant(f.context), + logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo) ) end end @@ -264,7 +240,6 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) adtype::ADTypes.AbstractADType, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Return an 'optimised' form of the adtype. This is useful for doing @@ -275,9 +250,7 @@ model. By default, this just returns the input unchanged. """ -tweak_adtype( - adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext -) = adtype +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype """ use_closure(adtype::ADTypes.AbstractADType) @@ -319,7 +292,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. """ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) + return LogDensityFunction(model, f.varinfo; adtype=f.adtype) end """ diff --git a/src/model.jl b/src/model.jl index 3b93fa14d..f46137ed1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,6 +85,12 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k return Model(f, args, NamedTuple(kwargs), context) end +""" + contextualize(model::Model, context::AbstractContext) + +Return a new `Model` with the same evaluation function and other arguments, but +with its underlying context set to `context`. 
+""" function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end @@ -794,15 +800,23 @@ julia> # Now `a.x` will be sampled. fixed(model::Model) = fixed(model.context) """ - (model::Model)([rng, varinfo, sampler, context]) + (model::Model)([rng, varinfo]) + +Sample from the prior of the `model` with random number generator `rng`. -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. +Returns the model's return value. -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Note that calling this with an existing `varinfo` object will mutate it. """ -(model::Model)(args...) = first(evaluate!!(model, args...)) +(model::Model)() = model(Random.default_rng(), VarInfo()) +function (model::Model)(varinfo::AbstractVarInfo) + return model(Random.default_rng(), varinfo) +end +# ^ Weird Documenter.jl bug means that we have to write the two above separately +# as it can only detect the `function`-less syntax. +function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) + return first(evaluate_and_sample!!(rng, model, varinfo)) +end """ use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) @@ -815,65 +829,69 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - evaluate!!(model::Model[, rng, varinfo, sampler, context]) + evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. +Evaluate the `model` with the given `varinfo`, but perform sampling during the +evaluation using the given `sampler` by wrapping the model's context in a +`SamplingContext`. 
-Returns both the return-value of the original model, and the resulting varinfo. +If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function AbstractPPL.evaluate!!( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) - return if use_threadsafe_eval(context, varinfo) - evaluate_threadsafe!!(model, varinfo, context) - else - evaluate_threadunsafe!!(model, varinfo, context) - end -end - -function AbstractPPL.evaluate!!( - model::Model, +function evaluate_and_sample!!( rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), + model::Model, + varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), ) - return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)) -end - -function AbstractPPL.evaluate!!(model::Model, context::AbstractContext) - return evaluate!!(model, VarInfo(), context) + sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return evaluate!!(sampling_model, varinfo) end - -function AbstractPPL.evaluate!!( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... +function evaluate_and_sample!!( + model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() ) - return evaluate!!(model, Random.default_rng(), args...) + return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end -# without VarInfo -function AbstractPPL.evaluate!!( - model::Model, - rng::Random.AbstractRNG, - sampler::AbstractSampler, - args::AbstractContext..., -) - return evaluate!!(model, rng, VarInfo(), sampler, args...) -end +""" + evaluate!!(model::Model, varinfo) + +Evaluate the `model` with the given `varinfo`. 
+ +If multiple threads are available, the varinfo provided will be wrapped in a +`ThreadSafeVarInfo` before evaluation. -# without VarInfo and without AbstractSampler +Returns a tuple of the model's return value, plus the updated `varinfo` +(unwrapped if necessary). + + evaluate!!(model::Model, varinfo, context) + +When an extra context stack is provided, the model's context is inserted into +that context stack. See `combine_model_and_external_contexts`. This method is +deprecated. +""" +function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) + return if use_threadsafe_eval(model.context, varinfo) + evaluate_threadsafe!!(model, varinfo) + else + evaluate_threadunsafe!!(model, varinfo) + end +end function AbstractPPL.evaluate!!( - model::Model, rng::Random.AbstractRNG, context::AbstractContext + model::Model, varinfo::AbstractVarInfo, context::AbstractContext ) - return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context) + Base.depwarn( + "The `context` argument to evaluate!!(model, varinfo, context) is deprecated.", + :dynamicppl_evaluate_context, + ) + new_ctx = combine_model_and_external_contexts(model.context, context) + model = contextualize(model, new_ctx) + return evaluate!!(model, varinfo) end """ - evaluate_threadunsafe!!(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. 
@@ -882,8 +900,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ -function evaluate_threadunsafe!!(model, varinfo, context) - return _evaluate!!(model, resetlogp!!(varinfo), context) +function evaluate_threadunsafe!!(model, varinfo) + return _evaluate!!(model, resetlogp!!(varinfo)) end """ @@ -897,31 +915,78 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ -function evaluate_threadsafe!!(model, varinfo, context) +function evaluate_threadsafe!!(model, varinfo) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper, context) + result, wrapper_new = _evaluate!!(model, wrapper) + # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it + # will return the underlying VI, which is a bit counterintuitive (because + # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it + # again). return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) end """ + _evaluate!!(model::Model, varinfo) + +Evaluate the `model` with the given `varinfo`. + +This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not +reset the log probability of the `varinfo` before running. + _evaluate!!(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. +If an additional `context` is provided, the model's context is combined with +that context before evaluation. """ -function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - args, kwargs = make_evaluate_args_and_kwargs(model, varinfo, context) +function _evaluate!!(model::Model, varinfo::AbstractVarInfo) + args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) return model.f(args...; kwargs...) 
end +function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) + # TODO(penelopeysm): We don't really need this, but it's a useful + # convenience method. We could remove it after we get rid of the + # evaluate_threadsafe!! stuff (in favour of making users call evaluate!! + # with a TSVI themselves). + new_ctx = combine_model_and_external_contexts(model.context, context) + model = contextualize(model, new_ctx) + return _evaluate!!(model, varinfo) +end is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") """ - make_evaluate_args_and_kwargs(model, varinfo, context) + combine_model_and_external_contexts(model_context, external_context) + +Combine a context from a model and an external context into a single context. + +The resulting context stack has the following structure: + + `external_context` -> `childcontext(external_context)` -> ... -> + `model_context` -> `childcontext(model_context)` -> ... -> + `leafcontext(external_context)` + +The reason for this is that we want to give `external_context` precedence over +`model_context`, while also preserving the leaf context of `external_context`. +We can do this by + +1. Set the leaf context of `model_context` to `leafcontext(external_context)`. +2. Set leaf context of `external_context` to the context resulting from (1). +""" +function combine_model_and_external_contexts( + model_context::AbstractContext, external_context::AbstractContext +) + return setleafcontext( + external_context, setleafcontext(model_context, leafcontext(external_context)) + ) +end + +""" + make_evaluate_args_and_kwargs(model, varinfo) Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`e. 
""" @generated function make_evaluate_args_and_kwargs( - model::Model{_F,argnames}, varinfo::AbstractVarInfo, context::AbstractContext + model::Model{_F,argnames}, varinfo::AbstractVarInfo ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) @@ -930,18 +995,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the :($matchingvalue(varinfo, model.args.$var)) end for var in argnames ] - - # We want to give `context` precedence over `model.context` while also - # preserving the leaf context of `context`. We can do this by - # 1. Set the leaf context of `model.context` to `leafcontext(context)`. - # 2. Set leaf context of `context` to the context resulting from (1). - # The result is: - # `context` -> `childcontext(context)` -> ... -> `model.context` - # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` return quote - context_new = setleafcontext( - context, setleafcontext(model.context, leafcontext(context)) - ) args = ( model, # Maybe perform `invlink!!` once prior to evaluation to avoid @@ -949,7 +1003,6 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. maybe_invlink_before_eval!!(varinfo, model), - context_new, $(unwrap_args...), ) kwargs = model.defaults @@ -985,15 +1038,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate!!( - model, - SimpleVarInfo{Float64}(OrderedDict()), - # NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`, - # and b) avoid double-stacking the parent contexts. 
- SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)), - ), - ) + x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) return values_as(x, T) end @@ -1010,7 +1055,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo))) end """ @@ -1064,7 +1109,7 @@ function logprior(model::Model, varinfo::AbstractVarInfo) LogPriorAccumulator() end varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,)) - return getlogprior(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogprior(last(evaluate!!(model, varinfo))) end """ @@ -1118,7 +1163,7 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) LogLikelihoodAccumulator() end varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,)) - return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext()))) + return getloglikelihood(last(evaluate!!(model, varinfo))) end """ @@ -1158,7 +1203,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end """ - predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) + predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) Generate samples from the posterior predictive distribution by evaluating `model` at each set of parameter values provided in `chain`. 
The number of posterior predictive samples matches @@ -1172,7 +1217,7 @@ function predict( return map(chain) do params_varinfo vi = deepcopy(varinfo) DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi, SampleFromPrior()) + model(rng, vi) return vi end end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index b6b97c8f9..59cc5e1bb 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -3,9 +3,10 @@ An accumulator that stores the log-probabilities of each variable in a model. -Internally this context stores the log-probabilities in a dictionary, where the keys are -the variable names and the values are vectors of log-probabilities. Each element in a vector -corresponds to one execution of the model. +Internally this accumulator stores the log-probabilities in a dictionary, where +the keys are the variable names and the values are vectors of +log-probabilities. Each element in a vector corresponds to one execution of the +model. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies which log-probabilities to store in the accumulator. `KeyType` is the type by which variable @@ -98,7 +99,6 @@ end model::Model, chain::Chains, keytype=String, - context=DefaultContext(), ::Val{whichlogprob}=Val(:both), ) @@ -107,9 +107,9 @@ with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. `context` is the evaluation context, -and `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, -`:prior`, or `:likelihood`. +Currently, only `String` and `VarName` are supported. `whichlogprob` specifies +which log-probabilities to compute. It can be `:both`, `:prior`, or +`:likelihood`. 
See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). @@ -211,11 +211,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ``` """ function pointwise_logdensities( - model::Model, - chain, - ::Type{KeyType}=String, - context::AbstractContext=DefaultContext(), - ::Val{whichlogprob}=Val(:both), + model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both) ) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) @@ -229,7 +225,7 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - vi = last(evaluate!!(model, vi, context)) + vi = last(evaluate!!(model, vi)) end logps = getacc(vi, Val(accumulator_name(AccType))).logps @@ -242,55 +238,46 @@ function pointwise_logdensities( end function pointwise_logdensities( - model::Model, - varinfo::AbstractVarInfo, - context::AbstractContext=DefaultContext(), - ::Val{whichlogprob}=Val(:both), + model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) ) where {whichlogprob} AccType = PointwiseLogProbAccumulator{whichlogprob} varinfo = setaccs!!(varinfo, (AccType(),)) - varinfo = last(evaluate!!(model, varinfo, context)) + varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ - pointwise_loglikelihoods(model, chain[, keytype, context]) + pointwise_loglikelihoods(model, chain[, keytype]) Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the likelihood terms. See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). 
""" -function pointwise_loglikelihoods( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} - return pointwise_logdensities(model, chain, T, context, Val(:likelihood)) +function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} + return pointwise_logdensities(model, chain, T, Val(:likelihood)) end -function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - return pointwise_logdensities(model, varinfo, context, Val(:likelihood)) +function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:likelihood)) end """ - pointwise_prior_logdensities(model, chain[, keytype, context]) + pointwise_prior_logdensities(model, chain[, keytype]) Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the prior terms. See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). 
""" function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() + model::Model, chain, keytype::Type{T}=String ) where {T} - return pointwise_logdensities(model, chain, T, context, Val(:prior)) + return pointwise_logdensities(model, chain, T, Val(:prior)) end -function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - return pointwise_logdensities(model, varinfo, context, Val(:prior)) +function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/src/sampler.jl b/src/sampler.jl index 49d910fec..673b5128f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,12 +58,12 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - model(rng, vi, sampler) + DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) return vi, nothing end """ - default_varinfo(rng, model, sampler[, context]) + default_varinfo(rng, model, sampler) Return a default varinfo object for the given `model` and `sampler`. @@ -71,22 +71,13 @@ Return a default varinfo object for the given `model` and `sampler`. - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. - `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. -- `context::AbstractContext`: Context in which the model is evaluated. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. 
""" function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - return default_varinfo(rng, model, sampler, DefaultContext()) -end -function default_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, -) init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler) end function AbstractMCMC.sample( @@ -119,7 +110,7 @@ function AbstractMCMC.step( # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi, DefaultContext())) + vi = last(evaluate!!(model, vi)) end return initialstep(rng, model, spl, vi; initial_params, kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 42fcedfb8..ea371c7da 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -36,13 +36,10 @@ julia> m = demo(); julia> rng = StableRNG(42); -julia> ### Sampling ### - ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); - julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -60,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. 
- _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -94,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! 
true @@ -128,7 +125,7 @@ julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) Positive probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: @@ -136,7 +133,7 @@ julia> # While if we forget to indicate that it's transformed: SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) julia> # (✓) No probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) -Inf ``` @@ -228,15 +225,25 @@ function SimpleVarInfo(; kwargs...) end # Constructor from `Model`. -function SimpleVarInfo( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... -) - return SimpleVarInfo{LogProbType}(model, args...) +function SimpleVarInfo{T}( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) where {T<:Real} + new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return last(evaluate!!(new_model, SimpleVarInfo{T}())) end function SimpleVarInfo{T}( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... 
+ model::Model, sampler::AbstractSampler=SampleFromPrior() ) where {T<:Real} - return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) + return SimpleVarInfo{T}(Random.default_rng(), model, sampler) +end +# Constructors without type param +function SimpleVarInfo( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) + return SimpleVarInfo{LogProbType}(rng, model, sampler) +end +function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) end # Constructor from `VarInfo`. @@ -252,12 +259,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict()) - return last(evaluate!!(model, varinfo, SamplingContext())) + return last(evaluate_and_sample!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate!!(model, varinfo, SamplingContext())) + return last(evaluate_and_sample!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index bd08b427e..67c3a8c18 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -207,41 +207,41 @@ ERROR: LoadError: cannot automatically prefix with no left-hand side resolved at runtime. """ macro submodel(prefix_expr, expr) - return submodel(prefix_expr, expr, esc(:__context__)) + return submodel(prefix_expr, expr, esc(:__model__)) end # Automatic prefixing. -function prefix_submodel_context(prefix::Bool, left::Symbol, ctx) - return prefix ? prefix_submodel_context(left, ctx) : ctx +function prefix_submodel_context(prefix::Bool, left::Symbol, model) + return prefix ? prefix_submodel_context(left, model) : :($model.context) end -function prefix_submodel_context(prefix::Bool, left::Expr, ctx) - return prefix ? 
prefix_submodel_context(varname(left), ctx) : ctx +function prefix_submodel_context(prefix::Bool, left::Expr, model) + return prefix ? prefix_submodel_context(varname(left), model) : :($model.context) end # Manual prefixing. -prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) -function prefix_submodel_context(prefix, ctx) +prefix_submodel_context(prefix, left, model) = prefix_submodel_context(prefix, model) +function prefix_submodel_context(prefix, model) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $model.context)) end -function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) +function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, model) # E.g. `prefix="asd"`. - return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $model.context)) end -function prefix_submodel_context(prefix::Bool, ctx) +function prefix_submodel_context(prefix::Bool, model) if prefix error("cannot automatically prefix with no left-hand side") end - return ctx + return :($model.context) end const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax." -function submodel(prefix_expr, expr, ctx=esc(:__context__)) +function submodel(prefix_expr, expr, model=esc(:__model__)) prefix_left, prefix = getargs_assignment(prefix_expr) if prefix_left !== :prefix error("$(prefix_left) is not a valid kwarg") @@ -257,7 +257,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) # `prefix=...` => use it. 
args_assign = getargs_assignment(expr) return if args_assign === nothing - ctx = prefix_submodel_context(prefix, ctx) + ctx = prefix_submodel_context(prefix, model) quote # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) @@ -271,7 +271,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) L, R = args_assign # Now that we have `L` and `R`, we can prefix automagically. try - ctx = prefix_submodel_context(prefix, L, ctx) + ctx = prefix_submodel_context(prefix, L, model) catch e error( "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 0c267c1c5..5285391b1 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -60,8 +60,6 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} model::Model "The VarInfo that was used" varinfo::AbstractVarInfo - "The evaluation context that was used" - context::AbstractContext "The values at which the model was evaluated" params::Vector{Tparams} "The AD backend that was tested" @@ -92,7 +90,6 @@ end grad_atol=1e-6, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - context::AbstractContext=DefaultContext(), reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, @@ -146,13 +143,7 @@ Everything else is optional, and can be categorised into several groups: prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. -3. _How to specify the evaluation context._ - - A `DynamicPPL.AbstractContext` can be passed as the `context` keyword - argument to control the evaluation context. This defaults to - `DefaultContext()`. - -4. 
_How to specify the results to compare against._ (Only if `test=true`.) +3. _How to specify the results to compare against._ (Only if `test=true`.) Once logp and its gradient has been calculated with the specified `adtype`, it must be tested for correctness. @@ -167,12 +158,12 @@ Everything else is optional, and can be categorised into several groups: The default reference backend is ForwardDiff. If none of these parameters are specified, ForwardDiff will be used to calculate the ground truth. -5. _How to specify the tolerances._ (Only if `test=true`.) +4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for the value and gradient can be set using `value_atol` and `grad_atol`. These default to 1e-6. -6. _Whether to output extra logging information._ +5. _Whether to output extra logging information._ By default, this function prints messages when it runs. To silence it, set `verbose=false`. @@ -195,7 +186,6 @@ function run_ad( grad_atol::AbstractFloat=1e-6, varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - context::AbstractContext=DefaultContext(), reference_adtype::AbstractADType=REFERENCE_ADTYPE, expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, @@ -207,7 +197,7 @@ function run_ad( verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo, context; adtype=adtype) + ldf = LogDensityFunction(model, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) grad = collect(grad) @@ -216,7 +206,7 @@ function run_ad( if test # Calculate ground truth to compare against value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype) + ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) logdensity_and_gradient(ldf_reference, 
params) else expected_value_and_grad @@ -245,7 +235,6 @@ function run_ad( return ADResult( model, varinfo, - context, params, adtype, value_atol, diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index ce79f2302..93aed074c 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -93,7 +93,7 @@ a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) return collect( - keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) + keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) ) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 07a308c7a..542fc17fc 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -48,7 +48,7 @@ function setup_varinfos( )) do vi # Set them all to the same values and evaluate logp. vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + last(DynamicPPL.evaluate!!(model, vi)) end if include_threadsafe diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cc07d70bb..51c57651d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -116,14 +116,17 @@ end # consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates # to define `getacc(vi)`. 
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + model = contextualize( + model, setleafcontext(model.context, DynamicTransformationContext{false}()) + ) + return settrans!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), + model = contextualize( + model, setleafcontext(model.context, DynamicTransformationContext{true}()) ) + return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) end function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) diff --git a/src/transforming.jl b/src/transforming.jl index ddd1ab59f..e3da0ff29 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -51,16 +51,17 @@ function _transform!!( vi::AbstractVarInfo, model::Model, ) - # To transform using DynamicTransformationContext, we evaluate the model, but we do not - # need to use any accumulators other than LogPriorAccumulator (which is affected by the Jacobian of - # the transformation). + # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: + model = contextualize(model, setleafcontext(model.context, ctx)) + # but we do not need to use any accumulators other than LogPriorAccumulator + # (which is affected by the Jacobian of the transformation). accs = getaccs(vi) has_logprior = haskey(accs, Val(:LogPrior)) if has_logprior old_logprior = getacc(accs, Val(:LogPrior)) vi = setaccs!!(vi, (old_logprior,)) end - vi = settrans!!(last(evaluate!!(model, vi, ctx)), t) + vi = settrans!!(last(evaluate!!(model, vi)), t) # Restore the accumulators. 
if has_logprior new_logprior = getacc(vi, Val(:LogPrior)) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4d6225c10..4922ddbb0 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -52,7 +52,7 @@ end accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc """ - values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) + values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) Get the values of `varinfo` as they would be seen in the model. @@ -69,8 +69,6 @@ space at the cost of additional model evaluations. - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: evaluation context to use in the extraction. Defaults - to `DynamicPPL.DefaultContext()`. # Examples @@ -124,14 +122,8 @@ julia> # Approach 2: Extract realizations using `values_as_in_model`. 
true ``` """ -function values_as_in_model( - model::Model, - include_colon_eq::Bool, - varinfo::AbstractVarInfo, - context::AbstractContext=DefaultContext(), -) - accs = getaccs(varinfo) +function values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),)) - varinfo = last(evaluate!!(model, varinfo, context)) + varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(:ValuesAsInModel)).values end diff --git a/src/varinfo.jl b/src/varinfo.jl index 20986d1a4..b3380e7f9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -106,10 +106,10 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler, context]) + VarInfo([rng, ]model[, sampler]) Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`, and `context`. +the given `rng`, `sampler`. !!! warning @@ -122,28 +122,12 @@ the given `rng`, `sampler`, and `context`. instead. 
""" function VarInfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_varinfo(rng, model, sampler, context) -end -function VarInfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return VarInfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(rng, model, sampler) end -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return VarInfo(rng, model, SampleFromPrior(), context) -end -function VarInfo(model::Model, context::AbstractContext) - # No sampler, no rng - return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return VarInfo(Random.default_rng(), model, sampler) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -200,42 +184,23 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `Metadata` as its metadata field. +Construct a VarInfo object for the given `model`, which has just a single +`Metadata` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
""" function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - varinfo = VarInfo(Metadata()) - context = SamplingContext(rng, sampler, context) - return last(evaluate!!(model, varinfo, context)) + return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) end -function untyped_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_varinfo(Random.default_rng(), model, sampler, context) -end -function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return untyped_varinfo(rng, model, SampleFromPrior(), context) -end -function untyped_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_varinfo(model, SampleFromPrior(), context) +function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_varinfo(Random.default_rng(), model, sampler) end """ @@ -298,96 +263,59 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler, context, metadata]) + typed_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
""" function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function typed_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return typed_varinfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(untyped_varinfo(rng, model, sampler)) end -function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return typed_varinfo(rng, model, SampleFromPrior(), context) -end -function typed_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_varinfo(model, SampleFromPrior(), context) +function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_varinfo(Random.default_rng(), model, sampler) end """ - untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `VarNamedVector` as its metadata field. +Return a VarInfo object for the given `model`, which has just a single +`VarNamedVector` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function untyped_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) -end -function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No sampler - return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) end -function untyped_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_vector_varinfo(model, SampleFromPrior(), context) +function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_vector_varinfo(Random.default_rng(), model, sampler) end """ - typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) + typed_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a -NamedTuple of `VarNamedVector`s as its metadata field. +Return a VarInfo object for the given `model`, which has a NamedTuple of +`VarNamedVector`s as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. 
-- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -399,30 +327,12 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) -end -function typed_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return typed_vector_varinfo(Random.default_rng(), model, sampler, context) -end -function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No sampler - return typed_vector_varinfo(rng, model, SampleFromPrior(), context) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) end -function typed_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_vector_varinfo(model, SampleFromPrior(), context) +function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_vector_varinfo(Random.default_rng(), model, sampler) end """ diff --git a/test/ad.jl b/test/ad.jl index c34624f5b..0947c017a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -110,9 +110,8 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = LogDensityFunction( - model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + ldf = LogDensityFunction(sampling_model, vi; 
adtype=AutoReverseDiff(; compile=true)) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end diff --git a/test/compiler.jl b/test/compiler.jl index 2e76de27f..2d1342fea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -185,10 +185,7 @@ module Issue537 end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng global lp = getlogjoint(__varinfo__) return x end @@ -196,18 +193,18 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - @test model_ === model - @test context_ isa SamplingContext - @test rng_ isa Random.AbstractRNG + # During the model evaluation, its context is wrapped in a + # SamplingContext, so `model_` is not going to be equal to `model`. + # We can still check equality of `f` though. + @test model_.f === model.f + @test model_.context isa SamplingContext + @test model_.context.rng isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng global lp = getlogjoint(__varinfo__) return x end false @@ -601,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) + retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. 
@model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -623,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/context_implementations.jl b/test/context_implementations.jl index ac6321d69..e16b2dc96 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -5,12 +5,12 @@ μ ~ MvNormal(zeros(2), 4 * I) z = Vector{Int}(undef, length(x)) z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x)))) - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(μ[z[i]], 0.1) end end - test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext()) + test([1, 1, -1])(VarInfo()) end @testset "dot tilde with varying sizes" begin diff --git a/test/contexts.jl b/test/contexts.jl index 1dd6a2280..597ab736c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -184,9 +184,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) + sampling_model = contextualize(model, context) # Sample with the context. 
varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(model, varinfo, context) + DynamicPPL.evaluate!!(sampling_model, varinfo) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d2269e089..8279ac51a 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,7 +1,7 @@ @testset "check_model" begin @testset "context interface" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext(model) + context = DynamicPPL.DebugUtils.DebugContext() DynamicPPL.TestUtils.test_context(context, model) end end @@ -35,9 +35,7 @@ buggy_model = buggy_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -81,9 +79,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -98,9 +94,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end @@ -115,9 +109,7 @@ buggy_model = buggy_subsumes_demo_model() @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + issuccess = check_model(buggy_model; 
record_varinfo=false) @test !issuccess @test_throws ErrorException check_model(buggy_model; error_on_failure=true) end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 73a0510e9..44db66296 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,17 +14,16 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) - context = DefaultContext() @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 86329a51d..6737cf056 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -62,6 +62,7 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation and sampling @@ -71,7 +72,7 @@ JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, DynamicPPL.SamplingContext() + sampling_model, varinfo ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. 
@@ -85,7 +86,7 @@ ) JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi, DynamicPPL.SamplingContext() + sampling_model, typed_vi ) JET.test_call(f_sample, argtypes_sample) end diff --git a/test/linking.jl b/test/linking.jl index 4f1707263..b0c2dcb5c 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -78,7 +78,7 @@ end vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),)) @testset "$(short_varinfo_name(vi))" for vi in vis # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) vi_linked = if mutable DynamicPPL.link!!(deepcopy(vi), model) else diff --git a/test/model.jl b/test/model.jl index ea260a68c..daa3cc743 100644 --- a/test/model.jl +++ b/test/model.jl @@ -162,12 +162,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) vals = vi[:] Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) @test vi[:] == vals end end @@ -223,7 +223,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) + evaluate_retval = DynamicPPL.evaluate!!(model, vi) @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo # Should not return `AbstractVarInfo` when we call the model. @@ -332,11 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) 
# Ensure `varnames` is implemented. - vi = last( - DynamicPPL.evaluate!!( - model, SimpleVarInfo(OrderedDict()), SamplingContext() - ), - ) + vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -397,7 +393,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] - context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -407,13 +402,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo)) true end varinfo_linked = DynamicPPL.link(varinfo, model) @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) true end end @@ -492,7 +487,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) + DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) end @@ -596,7 +591,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000] 
+ chain = [ + last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for + _ in 1:10000 + ] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 6f2f39a64..e300c651e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -98,7 +98,7 @@ for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) # `link!!` vi_linked = link!!(deepcopy(vi), model) @@ -158,7 +158,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -226,9 +226,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -236,7 +236,7 @@ for vn in keys(svi) svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + retval, svi = DynamicPPL.evaluate!!(model, svi) @test retval.m == svi[@varname(m)] # `m` is unconstrained @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` @@ -273,7 +273,7 @@ ) # Resulting varinfo should no longer be transformed. 
- vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) + vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. @@ -281,9 +281,7 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!( - model, deepcopy(vi_linked), DefaultContext() - ) + retval, vi_linked_result = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original diff --git a/test/threadsafe.jl b/test/threadsafe.jl index c673c8b36..24a738a78 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -52,9 +52,10 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wthreads(x) vi = VarInfo() - wthreads(x)(vi) + model(vi) lp_w_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo @@ -64,23 +65,19 @@ println("With `@threads`:") println(" default:") - @time wthreads(x)(vi) + @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. 
- DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) @test getlogjoint(vi) ≈ lp_w_threads + # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes + @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -89,9 +86,10 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wothreads(x) vi = VarInfo() - wothreads(x)(vi) + model(vi) lp_wo_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo @@ -101,24 +99,18 @@ println("Without `@threads`:") println(" default:") - @time wothreads(x)(vi) + @time model(vi) @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. 
- DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo + @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 053fd3203..d788e6215 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -488,10 +488,17 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model does not perform linking + # Check that instantiating the model using SampleFromUniform does not + # perform linking + # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) + # specifically in this test is because SFU samples from the linked + # distribution i.e. in unconstrained space. However, it does this not + # by linking the varinfo but by transforming the distributions on the + # fly. That's why it's worth specifically checking that it can do this + # without having to change the VarInfo object. vi = VarInfo() meta = vi.metadata - model(vi, SampleFromUniform()) + _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -565,7 +572,7 @@ end vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. 
- vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -574,7 +581,7 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -583,7 +590,7 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -592,7 +599,7 @@ end ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -600,7 +607,7 @@ end ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) # Sample in unconstrained space. 
- vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -608,7 +615,7 @@ end ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -690,7 +697,7 @@ end end # Evaluate the model once to update the logp of the varinfo. - varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) + varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) varinfo_linked = if mutating DynamicPPL.link!!(deepcopy(varinfo), model) @@ -993,9 +1000,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -1014,9 +1019,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. 
model2 = demo_dot(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index f21d458a8..57a8175d4 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -603,9 +603,7 @@ end DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) # Is evaluation correct? - varinfo_eval = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) - ) + varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo))) # Log density should be the same. @test getlogjoint(varinfo_eval) ≈ logp_true # Values should be the same. @@ -613,7 +611,7 @@ end # Is sampling correct? varinfo_sample = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) + DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) ) # Log density should be different. 
@test getlogjoint(varinfo_sample) != getlogjoint(varinfo) From 8b67e96adb2c22f8348fa2adfca364cd0f802bbb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 18:09:56 +0100 Subject: [PATCH 15/27] Mark function as Const for Enzyme tests (#957) --- test/integration/enzyme/main.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index 62b7ace4d..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -2,12 +2,14 @@ using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL.TestUtils.AD: run_ad using ADTypes: AutoEnzyme using Test: @test, @testset -import Enzyme: set_runtime_activity, Forward, Reverse +import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES From 1882f7270bb3160115f6dc5c1f0a4815898bbe9d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 26 Jun 2025 01:57:51 +0100 Subject: [PATCH 16/27] Move submodel code to submodel.jl; remove `@submodel` (#959) * Move submodel code to submodel.jl * Remove `@submodel` --- HISTORY.md | 4 + docs/src/api.md | 6 - src/DynamicPPL.jl | 3 +- src/model.jl | 240 ---------------------------------- src/submodel.jl | 239 ++++++++++++++++++++++++++++++++++ src/submodel_macro.jl | 290 ------------------------------------------ test/deprecated.jl | 57 --------- test/runtests.jl | 1 - 8 files changed, 244 insertions(+), 596 deletions(-) create mode 100644 src/submodel.jl delete mode 100644 src/submodel_macro.jl 
delete mode 100644 test/deprecated.jl diff --git a/HISTORY.md b/HISTORY.md index 9edac441f..d559e6373 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,10 @@ **Breaking changes** +### Submodel macro + +The `@submodel` macro is fully removed; please use `to_submodel` instead. + ### Accumulators This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: diff --git a/docs/src/api.md b/docs/src/api.md index 32b3d80a6..886d34a2f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -146,12 +146,6 @@ to_submodel Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations. 
-In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref) - -```@docs -@submodel -``` - In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 4bd4f2529..2b4d0e4a6 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -128,7 +128,6 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, - @submodel, value_iterator_from_chain, check_model, check_model_and_trace, @@ -172,6 +171,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end include("utils.jl") include("chains.jl") include("model.jl") +include("submodel.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") @@ -186,7 +186,6 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("submodel_macro.jl") include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") diff --git a/src/model.jl b/src/model.jl index f46137ed1..27551bfa2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1265,243 +1265,3 @@ end function returned(model::Model, values, keys) return returned(model, NamedTuple{keys}(values)) end - -""" - is_rhs_model(x) - -Return `true` if `x` is a model or model wrapper, and `false` otherwise. -""" -is_rhs_model(x) = false - -""" - Distributional - -Abstract type for type indicating that something is "distributional". -""" -abstract type Distributional end - -""" - should_auto_prefix(distributional) - -Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. -""" -function should_auto_prefix end - -""" - is_rhs_model(x) - -Return `true` if the `distributional` is a model, and `false` otherwise. 
-""" -function is_rhs_model end - -""" - Sampleable{M} <: Distributional - -A wrapper around a model indicating it is sampleable. -""" -struct Sampleable{M,AutoPrefix} <: Distributional - model::M -end - -should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix -is_rhs_model(x::Sampleable) = is_rhs_model(x.model) - -# TODO: Export this if it end up having a purpose beyond `to_submodel`. -""" - to_sampleable(model[, auto_prefix]) - -Return a wrapper around `model` indicating it is sampleable. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. -""" -to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) - -""" - rand_like!!(model_wrap, context, varinfo) - -Returns a tuple with the first element being the realization and the second the updated varinfo. - -# Arguments -- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. -- `context::AbstractContext`: the context to use for evaluation. -- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. - """ -function rand_like!!( - model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo -) - return rand_like!!(model_wrap.model, context, varinfo) -end - -""" - ReturnedModelWrapper - -A wrapper around a model indicating it is a model over its return values. - -This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. -""" -struct ReturnedModelWrapper{M<:Model} - model::M -end - -is_rhs_model(::ReturnedModelWrapper) = true - -function rand_like!!( - model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo -) - # Return's the value and the (possibly mutated) varinfo. - return _evaluate!!(model_wrap.model, varinfo, context) -end - -""" - returned(model) - -Return a `model` wrapper indicating that it is a model over its return-values. 
-""" -returned(model::Model) = ReturnedModelWrapper(model) - -""" - to_submodel(model::Model[, auto_prefix::Bool]) - -Return a model wrapper indicating that it is a sampleable model over the return-values. - -This is mainly meant to be used on the right-hand side of a `~` operator to indicate that -the model can be sampled from but not necessarily evaluated for its log density. - -!!! warning - Note that some other operations that one typically associate with expressions of the form - `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. - -!!! warning - To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. - If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand - side of the `~` statement. Default: `true`. - -# Examples - -## Simple example -```jldoctest submodel-to_submodel; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - a ~ to_submodel(demo1(x)) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel-to_submodel -julia> vi = VarInfo(demo2(missing, 0.4)); - -julia> @varname(a.x) in keys(vi) -true -``` - -The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, -and can be used in subsequent lines of the model, as shown above. 
-```jldoctest submodel-to_submodel -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel -julia> x = vi[@varname(a.x)]; - -julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` - -## Without automatic prefixing -As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically -prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel -will not be prefixed. -```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2_no_prefix(x, z) - a ~ to_submodel(demo1(x), false) - return z ~ Uniform(-a, 1) - end; - -julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); - -julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` -true -``` -However, not using prefixing is generally not recommended as it can lead to variable name clashes -unless one is careful. 
For example, if we're re-using the same model twice in a model, not using prefixing -will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): -```jldoctest submodel-to_submodel-prefix -julia> @model function demo2(x, y, z) - a ~ to_submodel(prefix(demo1(x), :sub1), false) - b ~ to_submodel(prefix(demo1(y), :sub2), false) - return z ~ Uniform(-a, b) - end; - -julia> vi = VarInfo(demo2(missing, missing, 0.4)); - -julia> @varname(sub1.x) in keys(vi) -true - -julia> @varname(sub2.x) in keys(vi) -true -``` - -Variables `a` and `b` are not tracked, but are assigned the return values of the respective -calls to `demo1`: -```jldoctest submodel-to_submodel-prefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogjoint(vi) ≈ logprior + loglikelihood -true -``` - -## Usage as likelihood is illegal - -Note that it is illegal to use a `to_submodel` model as a likelihood in another model: - -```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> @model illegal_likelihood() = a ~ to_submodel(inner()) -illegal_likelihood (generic function with 2 methods) - -julia> model = illegal_likelihood() | (a = 1.0,); - -julia> model() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported -[...] 
-``` -""" -to_submodel(model::Model, auto_prefix::Bool=true) = - to_sampleable(returned(model), auto_prefix) diff --git a/src/submodel.jl b/src/submodel.jl new file mode 100644 index 000000000..94658b6bf --- /dev/null +++ b/src/submodel.jl @@ -0,0 +1,239 @@ +""" + is_rhs_model(x) + +Return `true` if `x` is a model or model wrapper, and `false` otherwise. +""" +is_rhs_model(x) = false + +""" + Distributional + +Abstract type for type indicating that something is "distributional". +""" +abstract type Distributional end + +""" + should_auto_prefix(distributional) + +Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. +""" +function should_auto_prefix end + +""" + is_rhs_model(x) + +Return `true` if the `distributional` is a model, and `false` otherwise. +""" +function is_rhs_model end + +""" + Sampleable{M} <: Distributional + +A wrapper around a model indicating it is sampleable. +""" +struct Sampleable{M,AutoPrefix} <: Distributional + model::M +end + +should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix +is_rhs_model(x::Sampleable) = is_rhs_model(x.model) + +# TODO: Export this if it end up having a purpose beyond `to_submodel`. +""" + to_sampleable(model[, auto_prefix]) + +Return a wrapper around `model` indicating it is sampleable. + +# Arguments +- `model::Model`: the model to wrap. +- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. +""" +to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) + +""" + rand_like!!(model_wrap, context, varinfo) + +Returns a tuple with the first element being the realization and the second the updated varinfo. + +# Arguments +- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. +- `context::AbstractContext`: the context to use for evaluation. +- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. 
+ """ +function rand_like!!( + model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo +) + return rand_like!!(model_wrap.model, context, varinfo) +end + +""" + ReturnedModelWrapper + +A wrapper around a model indicating it is a model over its return values. + +This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. +""" +struct ReturnedModelWrapper{M<:Model} + model::M +end + +is_rhs_model(::ReturnedModelWrapper) = true + +function rand_like!!( + model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo +) + # Return's the value and the (possibly mutated) varinfo. + return _evaluate!!(model_wrap.model, varinfo, context) +end + +""" + returned(model) + +Return a `model` wrapper indicating that it is a model over its return-values. +""" +returned(model::Model) = ReturnedModelWrapper(model) + +""" + to_submodel(model::Model[, auto_prefix::Bool]) + +Return a model wrapper indicating that it is a sampleable model over the return-values. + +This is mainly meant to be used on the right-hand side of a `~` operator to indicate that +the model can be sampled from but not necessarily evaluated for its log density. + +!!! warning + Note that some other operations that one typically associate with expressions of the form + `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. + +!!! warning + To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. + If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. + +# Arguments +- `model::Model`: the model to wrap. +- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand + side of the `~` statement. Default: `true`. 
+ +# Examples + +## Simple example +```jldoctest submodel-to_submodel; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y) + a ~ to_submodel(demo1(x)) + return y ~ Uniform(0, a) + end; +``` + +When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +```jldoctest submodel-to_submodel +julia> vi = VarInfo(demo2(missing, 0.4)); + +julia> @varname(a.x) in keys(vi) +true +``` + +The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, +and can be used in subsequent lines of the model, as shown above. +```jldoctest submodel-to_submodel +julia> @varname(a) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel +julia> x = vi[@varname(a.x)]; + +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +true +``` + +## Without automatic prefixing +As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically +prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel +will not be prefixed. +```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2_no_prefix(x, z) + a ~ to_submodel(demo1(x), false) + return z ~ Uniform(-a, 1) + end; + +julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); + +julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` +true +``` +However, not using prefixing is generally not recommended as it can lead to variable name clashes +unless one is careful. 
For example, if we're re-using the same model twice in a model, not using prefixing +will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): +```jldoctest submodel-to_submodel-prefix +julia> @model function demo2(x, y, z) + a ~ to_submodel(prefix(demo1(x), :sub1), false) + b ~ to_submodel(prefix(demo1(y), :sub2), false) + return z ~ Uniform(-a, b) + end; + +julia> vi = VarInfo(demo2(missing, missing, 0.4)); + +julia> @varname(sub1.x) in keys(vi) +true + +julia> @varname(sub2.x) in keys(vi) +true +``` + +Variables `a` and `b` are not tracked, but are assigned the return values of the respective +calls to `demo1`: +```jldoctest submodel-to_submodel-prefix +julia> @varname(a) in keys(vi) +false + +julia> @varname(b) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel-prefix +julia> sub1_x = vi[@varname(sub1.x)]; + +julia> sub2_x = vi[@varname(sub2.x)]; + +julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); + +julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); + +julia> getlogjoint(vi) ≈ logprior + loglikelihood +true +``` + +## Usage as likelihood is illegal + +Note that it is illegal to use a `to_submodel` model as a likelihood in another model: + +```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> @model illegal_likelihood() = a ~ to_submodel(inner()) +illegal_likelihood (generic function with 2 methods) + +julia> model = illegal_likelihood() | (a = 1.0,); + +julia> model() +ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +[...] 
+``` +""" +to_submodel(model::Model, auto_prefix::Bool=true) = + to_sampleable(returned(model), auto_prefix) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl deleted file mode 100644 index 67c3a8c18..000000000 --- a/src/submodel_macro.jl +++ /dev/null @@ -1,290 +0,0 @@ -""" - @submodel model - @submodel ... = model - -Run a Turing `model` nested inside of a Turing model. - -!!! warning - This is deprecated and will be removed in a future release. - Use `left ~ to_submodel(model)` instead (see [`to_submodel`](@ref)). - -# Examples - -```jldoctest submodel; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - @submodel a = demo1(x) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel -julia> vi = VarInfo(demo2(missing, 0.4)); -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 - -julia> @varname(x) in keys(vi) -true -``` - -Variable `a` is not tracked since it can be computed from the random variable `x` that was -tracked when running `demo1`: -```jldoctest submodel -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel -julia> x = vi[@varname(x)]; - -julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` -""" -macro submodel(expr) - return submodel(:(prefix = false), expr) -end - -""" - @submodel prefix=... model - @submodel prefix=... ... = model - -Run a Turing `model` nested inside of a Turing model and add "`prefix`." as a prefix -to all random variables inside of the `model`. - -Valid expressions for `prefix=...` are: -- `prefix=false`: no prefix is used. 
-- `prefix=true`: _attempt_ to automatically determine the prefix from the left-hand side - `... = model` by first converting into a `VarName`, and then calling `Symbol` on this. -- `prefix=expression`: results in the prefix `Symbol(expression)`. - -The prefix makes it possible to run the same Turing model multiple times while -keeping track of all random variables correctly. - -!!! warning - This is deprecated and will be removed in a future release. - Use `left ~ to_submodel(model)` instead (see [`to_submodel(model)`](@ref)). - -# Examples -## Example models -```jldoctest submodelprefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y, z) - @submodel prefix="sub1" a = demo1(x) - @submodel prefix="sub2" b = demo1(y) - return z ~ Uniform(-a, b) - end; -``` - -When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and -`sub2.x` will be sampled: -```jldoctest submodelprefix -julia> vi = VarInfo(demo2(missing, missing, 0.4)); -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. 
-│ caller = ip:0x0 -└ @ Core :-1 - -julia> @varname(sub1.x) in keys(vi) -true - -julia> @varname(sub2.x) in keys(vi) -true -``` - -Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and -`sub2.x` that were tracked when running `demo1`: -```jldoctest submodelprefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodelprefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogjoint(vi) ≈ logprior + loglikelihood -true -``` - -## Different ways of setting the prefix -```jldoctest submodel-prefix-alternatives; setup=:(using DynamicPPL, Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> # When `prefix` is unspecified, no prefix is used. - @model submodel_noprefix() = @submodel a = inner() -submodel_noprefix (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_noprefix())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Explicitely don't use any prefix. - @model submodel_prefix_false() = @submodel prefix=false a = inner() -submodel_prefix_false (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Automatically determined from `a`. 
- @model submodel_prefix_true() = @submodel prefix=true a = inner() -submodel_prefix_true (generic function with 2 methods) - -julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Using a static string. - @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() -submodel_prefix_string (generic function with 2 methods) - -julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Using string interpolation. - @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() -submodel_prefix_interpolation (generic function with 2 methods) - -julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Or using some arbitrary expression. - @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() -submodel_prefix_expr (generic function with 2 methods) - -julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # (×) Automatic prefixing without a left-hand side expression does not work! - @model submodel_prefix_error() = @submodel prefix=true inner() -ERROR: LoadError: cannot automatically prefix with no left-hand side -[...] -``` - -# Notes -- The choice `prefix=expression` means that the prefixing will incur a runtime cost. 
- This is also the case for `prefix=true`, depending on whether the expression on the - the right-hand side of `... = model` requires runtime-information or not, e.g. - `x = model` will result in the _static_ prefix `x`, while `x[i] = model` will be - resolved at runtime. -""" -macro submodel(prefix_expr, expr) - return submodel(prefix_expr, expr, esc(:__model__)) -end - -# Automatic prefixing. -function prefix_submodel_context(prefix::Bool, left::Symbol, model) - return prefix ? prefix_submodel_context(left, model) : :($model.context) -end - -function prefix_submodel_context(prefix::Bool, left::Expr, model) - return prefix ? prefix_submodel_context(varname(left), model) : :($model.context) -end - -# Manual prefixing. -prefix_submodel_context(prefix, left, model) = prefix_submodel_context(prefix, model) -function prefix_submodel_context(prefix, model) - # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $model.context)) -end - -function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, model) - # E.g. `prefix="asd"`. - return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $model.context)) -end - -function prefix_submodel_context(prefix::Bool, model) - if prefix - error("cannot automatically prefix with no left-hand side") - end - - return :($model.context) -end - -const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax." - -function submodel(prefix_expr, expr, model=esc(:__model__)) - prefix_left, prefix = getargs_assignment(prefix_expr) - if prefix_left !== :prefix - error("$(prefix_left) is not a valid kwarg") - end - - # The user expects `@submodel ...` to return the - # return-value of the `...`, hence we need to capture - # the return-value and handle it correctly. - @gensym retval - - # `prefix=false` => don't prefix, i.e. do nothing to `ctx`. 
- # `prefix=true` => automatically determine prefix. - # `prefix=...` => use it. - args_assign = getargs_assignment(expr) - return if args_assign === nothing - ctx = prefix_submodel_context(prefix, model) - quote - # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. - $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) - - $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( - $(esc(expr)), $(esc(:__varinfo__)), $(ctx) - ) - $retval - end - else - L, R = args_assign - # Now that we have `L` and `R`, we can prefix automagically. - try - ctx = prefix_submodel_context(prefix, L, model) - catch e - error( - "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", - ) - end - quote - # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. - $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) - - $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( - $(esc(R)), $(esc(:__varinfo__)), $(ctx) - ) - $(esc(L)) = $retval - end - end -end diff --git a/test/deprecated.jl b/test/deprecated.jl deleted file mode 100644 index 500d3eb7f..000000000 --- a/test/deprecated.jl +++ /dev/null @@ -1,57 +0,0 @@ -@testset "deprecated" begin - @testset "@submodel" begin - @testset "is deprecated" begin - @model inner() = x ~ Normal() - @model outer() = @submodel x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer()() - ) - - @model outer_with_prefix() = @submodel prefix = "sub" x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... 
model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer_with_prefix()() - ) - end - - @testset "prefixing still works correctly" begin - @model inner() = x ~ Normal() - @model function outer() - a = @submodel inner() - b = @submodel prefix = "sub" inner() - return a, b - end - @test outer()() isa Tuple{Float64,Float64} - vi = VarInfo(outer()) - @test @varname(x) in keys(vi) - @test @varname(sub.x) in keys(vi) - end - - @testset "logp is still accumulated properly" begin - @model inner_assume() = x ~ Normal() - @model inner_observe(x, y) = y ~ Normal(x) - @model function outer(b) - a = @submodel inner_assume() - @submodel inner_observe(a, b) - end - y_val = 1.0 - model = outer(y_val) - @test model() == y_val - - x_val = 1.5 - vi = VarInfo(outer(y_val)) - DynamicPPL.setindex!!(vi, x_val, @varname(x)) - @test logprior(model, vi) ≈ logpdf(Normal(), x_val) - @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) - @test logjoint(model, vi) ≈ - logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 6fabcfe59..c60c06786 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -72,7 +72,6 @@ include("test_util.jl") include("context_implementations.jl") include("threadsafe.jl") include("debug_utils.jl") - include("deprecated.jl") include("submodels.jl") include("bijector.jl") end From 7f207097ecc21222ac13182d8bd8033ac1755df5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 26 Jun 2025 15:46:41 +0100 Subject: [PATCH 17/27] Fix missing field tests for 1.12 (#961) --- test/varinfo.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index d788e6215..cf03c1497 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -211,10 +211,12 @@ end ), ) @test getlogprior(vi) == lp_a + lp_b - @test_throws "has no field LogLikelihood" getloglikelihood(vi) - @test_throws "has no field 
LogLikelihood" getlogp(vi) - @test_throws "has no field LogLikelihood" getlogjoint(vi) - @test_throws "has no field NumProduce" get_num_produce(vi) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogLikelihood" getlogp(vi) + @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) + @test_throws r"has no field `?NumProduce" get_num_produce(vi) @test begin vi = acclogprior!!(vi, 1.0) getlogprior(vi) == lp_a + lp_b + 1.0 @@ -229,20 +231,24 @@ end m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),)) ), ) - @test_throws "has no field LogPrior" getlogprior(vi) - @test_throws "has no field LogLikelihood" getloglikelihood(vi) - @test_throws "has no field LogPrior" getlogp(vi) - @test_throws "has no field LogPrior" getlogjoint(vi) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogPrior" getlogprior(vi) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogPrior" getlogp(vi) + @test_throws r"has no field `?LogPrior" getlogjoint(vi) @test get_num_produce(vi) == 2 # Test evaluating without any accumulators. 
vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) - @test_throws "has no field LogPrior" getlogprior(vi) - @test_throws "has no field LogLikelihood" getloglikelihood(vi) - @test_throws "has no field LogPrior" getlogp(vi) - @test_throws "has no field LogPrior" getlogjoint(vi) - @test_throws "has no field NumProduce" get_num_produce(vi) - @test_throws "has no field NumProduce" reset_num_produce!!(vi) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogPrior" getlogprior(vi) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogPrior" getlogp(vi) + @test_throws r"has no field `?LogPrior" getlogjoint(vi) + @test_throws r"has no field `?NumProduce" get_num_produce(vi) + @test_throws r"has no field `?NumProduce" reset_num_produce!!(vi) end @testset "flags" begin From f20e86ccb652a89309d52e10e9ffba672206c7f5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 3 Jul 2025 17:22:38 +0100 Subject: [PATCH 18/27] Remove 3-argument `{_,}evaluate!!`; clean up submodel code (#960) * Clean up submodel code, remove 3-arg `_evaluate!!` * Remove 3-argument `evaluate!!` as well * Update changelog * Improve submodel error message * Fix doctest * Add error hint for three-argument evaluate!! --- HISTORY.md | 7 +- docs/src/api.md | 6 -- src/DynamicPPL.jl | 17 +++- src/compiler.jl | 6 +- src/context_implementations.jl | 38 ++------- src/model.jl | 59 +------------- src/submodel.jl | 144 ++++++++++++--------------------- src/test_utils/contexts.jl | 10 ++- 8 files changed, 85 insertions(+), 202 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d559e6373..617543a5f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -31,11 +31,11 @@ This version therefore excises the context argument, and instead uses `model.con The upshot of this is that many functions that previously took a context argument now no longer do. 
There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value). -`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. +`evaluate!!(model, varinfo, ext_context)` is removed, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost. If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely. -To aid with this process, `contextualize` is now exported from DynamicPPL. +**To aid with this process, `contextualize` is now exported from DynamicPPL.** The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`. Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object. 
@@ -54,9 +54,10 @@ However, here are the more user-facing ones: And a couple of more internal changes: - - `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments + - Just like `evaluate!!`, the other functions `_evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` now no longer accept context arguments - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`) - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument + - The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched. ## 0.36.12 diff --git a/docs/src/api.md b/docs/src/api.md index 886d34a2f..24efdae30 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -152,12 +152,6 @@ In the context of including models within models, it's also useful to prefix the DynamicPPL.prefix ``` -Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else - -```@docs -returned(::Model) -``` - ## Utilities It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. 
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 2b4d0e4a6..69e489ce6 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -171,11 +171,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end include("utils.jl") include("chains.jl") include("model.jl") -include("submodel.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") @@ -226,6 +226,21 @@ if isdefined(Base.Experimental, :register_error_hint) ) end end + + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ + is_evaluate_three_arg = + exc.f === AbstractPPL.evaluate!! && + length(argtypes) == 3 && + argtypes[1] <: Model && + argtypes[2] <: AbstractVarInfo && + argtypes[3] <: AbstractContext + if is_evaluate_three_arg + print( + io, + "\n\nThe method `evaluate!!(model, varinfo, new_ctx)` has been removed. Instead, you should store the `new_ctx` in the `model.context` field using `new_model = contextualize(model, new_ctx)`, and then call `evaluate!!(new_model, varinfo)` on the new model. 
(Note that, if the model already contained a non-default context, you will need to wrap the existing context.)", + ) + end + end end end diff --git a/src/compiler.jl b/src/compiler.jl index 22dff33a2..6384eaa7c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -176,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x)) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x -check_tilde_rhs(x::ReturnedModelWrapper) = x -function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} - model = check_tilde_rhs(x.model) - return Sampleable{typeof(model),AutoPrefix}(model) -end +check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x """ check_dot_tilde_rhs(x) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index cc75cd7e6..b11a723a5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -63,31 +63,10 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) - return if is_rhs_model(right) - # Here, we apply the PrefixContext _not_ to the parent `context`, but - # to the context of the submodel being evaluated. This means that later= - # on in `make_evaluate_args_and_kwargs`, the context stack will be - # correctly arranged such that it goes like this: - # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> - # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext - # See the docstring of `make_evaluate_args_and_kwargs`, and the internal - # DynamicPPL documentation on submodel conditioning, for more details. - # - # NOTE: This relies on the existence of `right.model.model`. Right now, - # the only thing that can return true for `is_rhs_model` is something - # (a `Sampleable`) that has a `model` field that itself (a - # `ReturnedModelWrapper`) has a `model` field. 
This may or may not - # change in the future. - if should_auto_prefix(right) - dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext(vn, dppl_model.context) - new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) - right = to_submodel(new_dppl_model, true) - end - rand_like!!(right, context, vi) + return if right isa DynamicPPL.Submodel + _evaluate!!(right, vi, context, vn) else - value, vi = tilde_assume(context, right, vn, vi) - return value, vi + tilde_assume(context, right, vn, vi) end end @@ -129,17 +108,14 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(context::DefaultContext, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) +function tilde_observe!!(::DefaultContext, right, left, vn, vi) + right isa DynamicPPL.Submodel && + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end -function assume(rng::Random.AbstractRNG, spl::Sampler, dist) +function assume(::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end diff --git a/src/model.jl b/src/model.jl index 27551bfa2..93e77eaec 100644 --- a/src/model.jl +++ b/src/model.jl @@ -258,7 +258,7 @@ julia> # However, it's not possible to condition `inner` directly. 
conditioned_model_fail = model | (inner = 1.0, ); julia> conditioned_model_fail() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed [...] ``` """ @@ -864,12 +864,6 @@ If multiple threads are available, the varinfo provided will be wrapped in a Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). - - evaluate!!(model::Model, varinfo, context) - -When an extra context stack is provided, the model's context is inserted into -that context stack. See `combine_model_and_external_contexts`. This method is -deprecated. """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) return if use_threadsafe_eval(model.context, varinfo) @@ -878,17 +872,6 @@ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) evaluate_threadunsafe!!(model, varinfo) end end -function AbstractPPL.evaluate!!( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) - Base.depwarn( - "The `context` argument to evaluate!!(model, varinfo, context) is deprecated.", - :dynamicppl_evaluate_context, - ) - new_ctx = combine_model_and_external_contexts(model.context, context) - model = contextualize(model, new_ctx) - return evaluate!!(model, varinfo) -end """ evaluate_threadunsafe!!(model, varinfo) @@ -932,54 +915,14 @@ Evaluate the `model` with the given `varinfo`. This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not reset the log probability of the `varinfo` before running. - - _evaluate!!(model::Model, varinfo, context) - -If an additional `context` is provided, the model's context is combined with -that context before evaluation. """ function _evaluate!!(model::Model, varinfo::AbstractVarInfo) args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) return model.f(args...; kwargs...) 
end -function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - # TODO(penelopeysm): We don't really need this, but it's a useful - # convenience method. We could remove it after we get rid of the - # evaluate_threadsafe!! stuff (in favour of making users call evaluate!! - # with a TSVI themselves). - new_ctx = combine_model_and_external_contexts(model.context, context) - model = contextualize(model, new_ctx) - return _evaluate!!(model, varinfo) -end is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") -""" - combine_model_and_external_contexts(model_context, external_context) - -Combine a context from a model and an external context into a single context. - -The resulting context stack has the following structure: - - `external_context` -> `childcontext(external_context)` -> ... -> - `model_context` -> `childcontext(model_context)` -> ... -> - `leafcontext(external_context)` - -The reason for this is that we want to give `external_context` precedence over -`model_context`, while also preserving the leaf context of `external_context`. -We can do this by - -1. Set the leaf context of `model_context` to `leafcontext(external_context)`. -2. Set leaf context of `external_context` to the context resulting from (1). -""" -function combine_model_and_external_contexts( - model_context::AbstractContext, external_context::AbstractContext -) - return setleafcontext( - external_context, setleafcontext(model_context, leafcontext(external_context)) - ) -end - """ make_evaluate_args_and_kwargs(model, varinfo) diff --git a/src/submodel.jl b/src/submodel.jl index 94658b6bf..dcb107bb4 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -1,98 +1,13 @@ """ - is_rhs_model(x) + Submodel{M,AutoPrefix} -Return `true` if `x` is a model or model wrapper, and `false` otherwise. +A wrapper around a model, plus a flag indicating whether it should be automatically +prefixed with the left-hand variable in a `~` statement. 
""" -is_rhs_model(x) = false - -""" - Distributional - -Abstract type for type indicating that something is "distributional". -""" -abstract type Distributional end - -""" - should_auto_prefix(distributional) - -Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. -""" -function should_auto_prefix end - -""" - is_rhs_model(x) - -Return `true` if the `distributional` is a model, and `false` otherwise. -""" -function is_rhs_model end - -""" - Sampleable{M} <: Distributional - -A wrapper around a model indicating it is sampleable. -""" -struct Sampleable{M,AutoPrefix} <: Distributional +struct Submodel{M,AutoPrefix} model::M end -should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix -is_rhs_model(x::Sampleable) = is_rhs_model(x.model) - -# TODO: Export this if it end up having a purpose beyond `to_submodel`. -""" - to_sampleable(model[, auto_prefix]) - -Return a wrapper around `model` indicating it is sampleable. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. -""" -to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) - -""" - rand_like!!(model_wrap, context, varinfo) - -Returns a tuple with the first element being the realization and the second the updated varinfo. - -# Arguments -- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. -- `context::AbstractContext`: the context to use for evaluation. -- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. - """ -function rand_like!!( - model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo -) - return rand_like!!(model_wrap.model, context, varinfo) -end - -""" - ReturnedModelWrapper - -A wrapper around a model indicating it is a model over its return values. - -This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. 
-""" -struct ReturnedModelWrapper{M<:Model} - model::M -end - -is_rhs_model(::ReturnedModelWrapper) = true - -function rand_like!!( - model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo -) - # Return's the value and the (possibly mutated) varinfo. - return _evaluate!!(model_wrap.model, varinfo, context) -end - -""" - returned(model) - -Return a `model` wrapper indicating that it is a model over its return-values. -""" -returned(model::Model) = ReturnedModelWrapper(model) - """ to_submodel(model::Model[, auto_prefix::Bool]) @@ -106,8 +21,8 @@ the model can be sampled from but not necessarily evaluated for its log density. `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. !!! warning - To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. - If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. + To avoid variable names clashing between models, it is recommended to leave the argument `auto_prefix` equal to `true`. + If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly, i.e. `to_submodel(prefix(model, @varname(my_prefix)))` # Arguments - `model::Model`: the model to wrap. @@ -231,9 +146,50 @@ illegal_likelihood (generic function with 2 methods) julia> model = illegal_likelihood() | (a = 1.0,); julia> model() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed [...] 
``` """ -to_submodel(model::Model, auto_prefix::Bool=true) = - to_sampleable(returned(model), auto_prefix) +to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) + +# When automatic prefixing is used, the submodel itself doesn't carry the +# prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel +# is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then +# passed into this function. +# +# `parent_context` here refers to the context of the model that contains the +# submodel. +function _evaluate!!( + submodel::Submodel{M,AutoPrefix}, + vi::AbstractVarInfo, + parent_context::AbstractContext, + left_vn::VarName, +) where {M<:Model,AutoPrefix} + # First, we construct the context to be used when evaluating the submodel. There + # are several considerations here: + # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but + # _only_ if automatic prefixing is supposed to be applied. + submodel_context_prefixed = if AutoPrefix + PrefixContext(left_vn, submodel.model.context) + else + submodel.model.context + end + + # (2) We need to respect the leaf-context of the parent model. This, unfortunately, + # means disregarding the leaf-context of the submodel. + submodel_context = setleafcontext( + submodel_context_prefixed, leafcontext(parent_context) + ) + + # (3) We need to use the parent model's context to wrap the whole thing, so that + # e.g. if the user conditions the parent model, the conditioned variables will be + # correctly picked up when evaluating the submodel. + eval_context = setleafcontext(parent_context, submodel_context) + + # (4) Finally, we need to store that context inside the submodel. + model = contextualize(submodel.model, eval_context) + + # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This + # returns a tuple of submodel.model's return value and the new varinfo. 
+ return _evaluate!!(model, vi) +end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 08acdfada..863db4262 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -63,10 +63,12 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. # Untyped varinfo. varinfo_untyped = DynamicPPL.VarInfo() - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) + model_with_spl = contextualize(model, SamplingContext(context)) + model_without_spl = contextualize(model, context) + @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any + @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any # Typed varinfo. varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) + @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any + @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any end From a0289db2e0e57289cc6a5d94cf67b1d73e098052 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Jul 2025 11:11:41 +0100 Subject: [PATCH 19/27] Improve API for AD testing (#964) * Rework API for AD testing * Fix test * Add `rng` keyword argument * Use atol and rtol * remove unbound type parameter (?) 
* Don't need to do elementwise check * Update changelog * Fix typo --- HISTORY.md | 11 +++ docs/src/api.md | 15 ++++ src/test_utils/ad.jl | 180 +++++++++++++++++++++++++++---------------- test/ad.jl | 16 ++-- 4 files changed, 149 insertions(+), 73 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 129e92b26..955a28963 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,6 +8,17 @@ The `@submodel` macro is fully removed; please use `to_submodel` instead. +### `DynamicPPL.TestUtils.AD.run_ad` + +The three keyword arguments, `test`, `reference_backend`, and `expected_value_and_grad` have been merged into a single `test` keyword argument. +Please see the API documentation for more details. +(The old `test=true` and `test=false` values are still valid, and you only need to adjust the invocation if you were explicitly passing the `reference_backend` or `expected_value_and_grad` arguments.) + +There is now also an `rng` keyword argument to help seed parameter generation. + +Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. +Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`. + ### Accumulators This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. 
This brings with it a number of breaking changes: diff --git a/docs/src/api.md b/docs/src/api.md index 367433740..e918a095c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -206,6 +206,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL ```@docs DynamicPPL.TestUtils.AD.run_ad +``` + +The default test setting is to compare against ForwardDiff. +You can have more fine-grained control over how to test the AD backend using the following types: + +```@docs +DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting +DynamicPPL.TestUtils.AD.WithBackend +DynamicPPL.TestUtils.AD.WithExpectedResult +DynamicPPL.TestUtils.AD.NoTest +``` + +These are returned / thrown by the `run_ad` function: + +```@docs DynamicPPL.TestUtils.AD.ADResult DynamicPPL.TestUtils.AD.ADIncorrectException ``` diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 5285391b1..155f3b68d 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,28 +4,57 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, - LogDensityFunction, - VarInfo, - AbstractVarInfo, - link, - DefaultContext, - AbstractContext +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link using LogDensityProblems: logdensity, logdensity_and_gradient -using Random: Random, Xoshiro +using Random: AbstractRNG, default_rng using Statistics: median using Test: @test -export ADResult, run_ad, ADIncorrectException +export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest """ - REFERENCE_ADTYPE + AbstractADCorrectnessTestSetting -Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since -it's the default AD backend used in Turing.jl. +Different ways of testing the correctness of an AD backend. 
""" -const REFERENCE_ADTYPE = AutoForwardDiff() +abstract type AbstractADCorrectnessTestSetting end + +""" + WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against the result obtained with `adtype`. + +`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in +Turing.jl. +""" +struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting + adtype::AD +end +WithBackend() = WithBackend(AutoForwardDiff()) + +""" + WithExpectedResult( + value::T, + grad::AbstractVector{T} + ) where {T <: AbstractFloat} + <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against a known result (e.g. one obtained +analytically, or one obtained with a different backend previously). Both the +value of the primal (i.e. the log-density) as well as its gradient must be +supplied. +""" +struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting + value::T + grad::AbstractVector{T} +end + +""" + NoTest() <: AbstractADCorrectnessTestSetting + +Disable correctness testing. +""" +struct NoTest <: AbstractADCorrectnessTestSetting end """ ADIncorrectException{T<:AbstractFloat} @@ -45,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception end """ - ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} + ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} Data structure to store the results of the AD correctness test. The type parameter `Tparams` is the numeric type of the parameters passed in; -`Tresult` is the type of the value and the gradient. +`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the +absolute and relative tolerances used for correctness testing. 
# Fields $(TYPEDFIELDS) """ -struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} +struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} "The DynamicPPL model that was tested" model::Model "The VarInfo that was used" @@ -64,18 +94,18 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} params::Vector{Tparams} "The AD backend that was tested" adtype::AbstractADType - "The absolute tolerance for the value of logp" - value_atol::Tresult - "The absolute tolerance for the gradient of logp" - grad_atol::Tresult + "Absolute tolerance used for correctness test" + atol::Ttol + "Relative tolerance used for correctness test" + rtol::Ttol "The expected value of logp" value_expected::Union{Nothing,Tresult} "The expected gradient of logp" grad_expected::Union{Nothing,Vector{Tresult}} "The value of logp (calculated using `adtype`)" - value_actual::Union{Nothing,Tresult} + value_actual::Tresult "The gradient of logp (calculated using `adtype`)" - grad_actual::Union{Nothing,Vector{Tresult}} + grad_actual::Vector{Tresult} "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" time_vs_primal::Union{Nothing,Tresult} end @@ -84,14 +114,12 @@ end run_ad( model::Model, adtype::ADTypes.AbstractADType; - test=true, + test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(), benchmark=false, - value_atol=1e-6, - grad_atol=1e-6, + atol::AbstractFloat=1e-8, + rtol::AbstractFloat=sqrt(eps()), varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult @@ -133,8 +161,8 @@ Everything else is optional, and can be categorised into several groups: Note that if the VarInfo is not specified (and thus automatically 
generated) the parameters in it will have been sampled from the prior of the model. If - you want to seed the parameter generation, the easiest way is to pass a - `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`). + you want to seed the parameter generation for the VarInfo, you can pass the + `rng` keyword argument, which will then be used to create the VarInfo. Finally, note that these only reflect the parameters used for _evaluating_ the gradient. If you also want to control the parameters used for @@ -143,25 +171,35 @@ Everything else is optional, and can be categorised into several groups: prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. -3. _How to specify the results to compare against._ (Only if `test=true`.) +3. _How to specify the results to compare against._ Once logp and its gradient has been calculated with the specified `adtype`, - it must be tested for correctness. + it can optionally be tested for correctness. The exact way this is tested + is specified in the `test` parameter. + + There are several options for this: - This can be done either by specifying `reference_adtype`, in which case logp - and its gradient will also be calculated with this reference in order to - obtain the ground truth; or by using `expected_value_and_grad`, which is a - tuple of `(logp, gradient)` that the calculated values must match. The - latter is useful if you are testing multiple AD backends and want to avoid - recalculating the ground truth multiple times. + - You can explicitly specify the correct value using + [`WithExpectedResult()`](@ref). + - You can compare against the result obtained with a different AD backend + using [`WithBackend(adtype)`](@ref). + - You can disable testing by passing [`NoTest()`](@ref). + - The default is to compare against the result obtained with ForwardDiff, + i.e. `WithBackend(AutoForwardDiff())`. 
+ - `test=false` and `test=true` are synonyms for + `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively. - The default reference backend is ForwardDiff. If none of these parameters are - specified, ForwardDiff will be used to calculate the ground truth. +4. _How to specify the tolerances._ (Only if testing is enabled.) -4. _How to specify the tolerances._ (Only if `test=true`.) + Both absolute and relative tolerances can be specified using the `atol` and + `rtol` keyword arguments respectively. The behaviour of these is similar to + `isapprox()`, i.e. the value and gradient are considered correct if either + atol or rtol is satisfied. The default values are `100*eps()` for `atol` and + `sqrt(eps())` for `rtol`. - The tolerances for the value and gradient can be set using `value_atol` and - `grad_atol`. These default to 1e-6. + For the most part, it is the `rtol` check that is more meaningful, because + we cannot know the magnitude of logp and its gradient a priori. The `atol` + value is supplied to handle the case where gradients are equal to zero. 5. _Whether to output extra logging information._ @@ -180,48 +218,58 @@ thrown as-is. function run_ad( model::Model, adtype::AbstractADType; - test::Bool=true, + test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(), benchmark::Bool=false, - value_atol::AbstractFloat=1e-6, - grad_atol::AbstractFloat=1e-6, - varinfo::AbstractVarInfo=link(VarInfo(model), model), + atol::AbstractFloat=100 * eps(), + rtol::AbstractFloat=sqrt(eps()), + rng::AbstractRNG=default_rng(), + varinfo::AbstractVarInfo=link(VarInfo(rng, model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - reference_adtype::AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult + # Convert Boolean `test` to an AbstractADCorrectnessTestSetting + if test isa Bool + test = test ? 
WithBackend() : NoTest() + end + + # Extract parameters if isnothing(params) params = varinfo[:] end params = map(identity, params) # Concretise + # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") ldf = LogDensityFunction(model, varinfo; adtype=adtype) - value, grad = logdensity_and_gradient(ldf, params) + # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad = collect(grad) verbose && println(" actual : $((value, grad))") - if test - # Calculate ground truth to compare against - value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) - logdensity_and_gradient(ldf_reference, params) - else - expected_value_and_grad + # Test correctness + if test isa NoTest + value_true = nothing + grad_true = nothing + else + # Get the correct result + if test isa WithExpectedResult + value_true = test.value + grad_true = test.grad + elseif test isa WithBackend + ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype) + value_true, grad_true = logdensity_and_gradient(ldf_reference, params) + # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 + grad_true = collect(grad_true) end + # Perform testing verbose && println(" expected : $((value_true, grad_true))") - grad_true = collect(grad_true) - exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) - isapprox(value, value_true; atol=value_atol) || exc() - isapprox(grad, grad_true; atol=grad_atol) || exc() - else - value_true = nothing - grad_true = nothing + isapprox(value, value_true; atol=atol, rtol=rtol) || exc() + isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc() end + # Benchmark time_vs_primal = if benchmark primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) grad_benchmark = @be (ldf, params) 
logdensity_and_gradient(_[1], _[2]) @@ -237,8 +285,8 @@ function run_ad( varinfo, params, adtype, - value_atol, - grad_atol, + atol, + rtol, value_true, grad_true, value, diff --git a/test/ad.jl b/test/ad.jl index 0947c017a..48dffeadb 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,5 @@ using DynamicPPL: LogDensityFunction +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -31,9 +32,10 @@ using DynamicPPL: LogDensityFunction linked_varinfo = DynamicPPL.link(varinfo, m) f = LogDensityFunction(m, linked_varinfo) x = DynamicPPL.getparams(f) + # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual @testset "$adtype" for adtype in test_adtypes @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" @@ -50,24 +52,24 @@ using DynamicPPL: LogDensityFunction if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, linked_varinfo; adtype=adtype ) else - @test DynamicPPL.TestUtils.AD.run_ad( + @test run_ad( m, adtype; varinfo=linked_varinfo, - expected_value_and_grad=(ref_logp, ref_grad), + test=WithExpectedResult(ref_logp, 
ref_grad), ) isa Any end end From cba604b3deee6301efa4c9f942a3df8fd0155316 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 16 Jul 2025 17:53:49 +0100 Subject: [PATCH 20/27] DebugAccumulator (plus tiny bits and pieces) (#976) * DebugContext -> DebugAccumulator * Changelog * Force `conditioned` to return a dict * fix conditioned implementation * revert `conditioned` bugfix (will merge this to main instead) * fix show * Fix doctests * fix doctests 2 * Make VarInfo actually mandatory in check_model * Re-implement `missing` check * Revert `combine` signature in docstring * Revert changes to `Base.show` on AccumulatorTuple * Add TODO comment about VariableOrderAccumulator Co-authored-by: Markus Hauru * Fix doctests --------- Co-authored-by: Markus Hauru --- HISTORY.md | 36 ++++- src/accumulators.jl | 14 +- src/debug_utils.jl | 322 ++++++++++++++++++++------------------------ test/debug_utils.jl | 85 +++++++----- 4 files changed, 241 insertions(+), 216 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index fcd005579..d367e9ad7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,15 @@ ## 0.37.0 -**Breaking changes** +DynamicPPL 0.37 comes with a substantial reworking of its internals. +Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release is unlikely to affect you much. +However, if you are a package developer or someone who uses DynamicPPL's functionality directly, you will notice a number of changes. + +To avoid overwhelming the reader, we begin by listing the most important, user-facing changes, before explaining the changes to the internals in more detail. + +Note that virtually all changes listed here are breaking. + +**Public-facing changes** ### Submodel macro @@ -19,6 +27,32 @@ There is now also an `rng` keyword argument to help seed parameter generation. 
Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
 Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
+
+### `DynamicPPL.TestUtils.check_model`
+
+You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
+Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`).
+
+### Removal of `PriorContext` and `LikelihoodContext`
+
+A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`.
+Although these are not the only _exported_ contexts, we consider it unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below.
+
+Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field.
+`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively.
+
+Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below).
+If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood).
+
+If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`.
+`LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want.
+Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`.
+
+The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior.
+Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
+Now, you can write `@addlogprob! (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood.
+
+**Internals**
+
 ### Accumulators
 
 This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
diff --git a/src/accumulators.jl b/src/accumulators.jl
index 10a988ae5..595c45d3f 100644
--- a/src/accumulators.jl
+++ b/src/accumulators.jl
@@ -53,10 +53,11 @@ function accumulate_observe!! end
 
 Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`.
 
-`vn` is the name of the variable being assumed, `val` is the value of the variable, and
-`right` is the distribution on the RHS of the tilde statement. `logjac` is the log
-determinant of the Jacobian of the transformation that was done to convert the value of `vn`
-as it was given (e.g. by sampler operating in linked space) to `val`.
+`vn` is the name of the variable being assumed, `val` is the value of the variable (in the
+original, unlinked space), and `right` is the distribution on the RHS of the tilde
+statement. `logjac` is the log determinant of the Jacobian of the transformation that was
+done to convert the value of `vn` as it was given to `val`: for example, if the sampler is
+operating in linked (Euclidean) space, then logjac will be nonzero.
 
 `accumulate_assume!!` may mutate `acc`, but not any of the other arguments.
 
@@ -71,7 +72,7 @@ Return a new accumulator like `acc` but empty. The precise meaning of "empty" is that that the returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading -where different threads may accumulate independently and the results are the combined. +where different threads may accumulate independently and the results are then combined. See also: [`combine`](@ref) """ @@ -80,7 +81,8 @@ function split end """ combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) -Combine two accumulators of the same type. Returns a new accumulator. +Combine two accumulators which have the same type (but may, in general, have different type +parameters). Returns a new accumulator of the same type. See also: [`split`](@ref) """ diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 4343ce8ac..d1add6e00 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -49,7 +49,7 @@ end function show_right(io::IO, d::Distribution) pnames = fieldnames(typeof(d)) - uml, namevals = Distributions._use_multline_show(d, pnames) + _, namevals = Distributions._use_multline_show(d, pnames) return Distributions.show_oneline(io, d, namevals) end @@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - varinfo = nothing end function Base.show(io::IO, stmt::AssumeStmt) @@ -88,21 +87,30 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, " ") print(io, RESULT_SYMBOL) print(io, " ") - return print(io, stmt.value) + print(io, stmt.value) + return nothing end Base.@kwdef struct ObserveStmt <: Stmt - left + varname right - varinfo = nothing + value end function Base.show(io::IO, stmt::ObserveStmt) io = add_io_context(io) - print(io, "observe: ") - show_right(io, stmt.left) + print(io, " observe: ") + if stmt.varname === nothing + print(io, stmt.value) + else + show_varname(io, stmt.varname) + print(io, " (= ") + print(io, stmt.value) + print(io, ")") + end print(io, " ~ ") - return 
show_right(io, stmt.right) + show_right(io, stmt.right) + return nothing end # Some utility methods for extracting information from a trace. @@ -124,98 +132,88 @@ distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] """ - DebugContext <: AbstractContext + DebugAccumulator <: AbstractAccumulator -A context used for checking validity of a model. +An accumulator which captures tilde-statements inside a model and attempts to catch +errors in the model. # Fields -$(FIELDS) +$(TYPEDFIELDS) """ -struct DebugContext{C<:AbstractContext} <: AbstractContext - "context used for running the model" - context::C +struct DebugAccumulator <: AbstractAccumulator "mapping from varnames to the number of times they have been seen" varnames_seen::OrderedDict{VarName,Int} "tilde statements that have been executed" statements::Vector{Stmt} - "whether to throw an error if we encounter warnings" + "whether to throw an error if we encounter errors in the model" error_on_failure::Bool - "whether to record the tilde statements" - record_statements::Bool - "whether to record the varinfo in every tilde statement" - record_varinfo::Bool -end - -function DebugContext( - context::AbstractContext=DefaultContext(); - varnames_seen=OrderedDict{VarName,Int}(), - statements=Vector{Stmt}(), - error_on_failure=false, - record_statements=true, - record_varinfo=false, -) - return DebugContext( - context, - varnames_seen, - statements, - error_on_failure, - record_statements, - record_varinfo, - ) end -DynamicPPL.NodeTrait(::DebugContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::DebugContext) = context.context -function DynamicPPL.setchildcontext(context::DebugContext, child) - Accessors.@set context.context = child +function DebugAccumulator(error_on_failure=false) + return DebugAccumulator(OrderedDict{VarName,Int}(), Vector{Stmt}(), error_on_failure) end -function record_varname!(context::DebugContext, varname::VarName, dist) - 
prefixed_varname = DynamicPPL.prefix(context, varname) - if haskey(context.varnames_seen, prefixed_varname) - if context.error_on_failure - error("varname $prefixed_varname used multiple times in model") +const _DEBUG_ACC_NAME = :Debug +DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME + +function split(acc::DebugAccumulator) + return DebugAccumulator( + OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure + ) +end +function combine(acc1::DebugAccumulator, acc2::DebugAccumulator) + return DebugAccumulator( + merge(acc1.varnames_seen, acc2.varnames_seen), + vcat(acc1.statements, acc2.statements), + acc1.error_on_failure || acc2.error_on_failure, + ) +end + +function record_varname!(acc::DebugAccumulator, varname::VarName, dist) + if haskey(acc.varnames_seen, varname) + if acc.error_on_failure + error("varname $varname used multiple times in model") else - @warn "varname $prefixed_varname used multiple times in model" + @warn "varname $varname used multiple times in model" end - context.varnames_seen[prefixed_varname] += 1 + acc.varnames_seen[varname] += 1 else # We need to check: # 1. Does this `varname` subsume any of the other keys. # 2. Does any of the other keys subsume `varname`. - vns = collect(keys(context.varnames_seen)) + vns = collect(keys(acc.varnames_seen)) # Is `varname` subsumed by any of the other keys? 
- idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns) + idx_parent = findfirst(Base.Fix2(subsumes, varname), vns) if idx_parent !== nothing varname_parent = vns[idx_parent] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)", + "varname $(varname_parent) used multiple times in model (subsumes $varname)", ) else - @warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)" + @warn "varname $(varname_parent) used multiple times in model (subsumes $varname)" end # Update count of parent. - context.varnames_seen[varname_parent] += 1 + acc.varnames_seen[varname_parent] += 1 else # Does `varname` subsume any of the other keys? - idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns) + idx_child = findfirst(Base.Fix1(subsumes, varname), vns) if idx_child !== nothing varname_child = vns[idx_child] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)", + "varname $(varname_child) used multiple times in model (subsumed by $varname)", ) else - @warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)" + @warn "varname $(varname_child) used multiple times in model (subsumed by $varname)" end # Update count of child. 
- context.varnames_seen[varname_child] += 1 + acc.varnames_seen[varname_child] += 1 end end - context.varnames_seen[prefixed_varname] = 1 + acc.varnames_seen[varname] = 1 end end @@ -233,83 +231,56 @@ end _has_nans(x::NamedTuple) = any(_has_nans, x) _has_nans(x::AbstractArray) = any(_has_nans, x) _has_nans(x) = isnan(x) +_has_nans(::Missing) = false -# assume -function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) - record_varname!(context, vn, dist) - return nothing -end - -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo) - stmt = AssumeStmt(; - varname=vn, - right=dist, - value=value, - varinfo=context.record_varinfo ? varinfo : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing +function DynamicPPL.accumulate_assume!!( + acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution +) + record_varname!(acc, vn, right) + stmt = AssumeStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end -function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) - record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, vi) - return value, vi -end -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi +function DynamicPPL.accumulate_observe!!( + acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing} ) - record_pre_tilde_assume!(context, vn, right, vi) - value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, vi) - return value, vi -end - -# observe -function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) - # Check for `missing`s; these should not end up here. 
- if _has_missings(left) - error( - "Encountered `missing` value(s) on the left-hand side" * - " of an observe statement. Using `missing` to de-condition" * - " a variable is only supported for univariate distributions," * - " not for $dist.", + if _has_missings(val) + # If `val` itself is a missing, that's a bug because that should cause + # us to go down the assume path. + val === missing && error( + "Encountered `missing` value on the left-hand side of an observe" * + " statement. This should not happen. Please open an issue at" * + " https://github.com/TuringLang/DynamicPPL.jl.", ) + # Otherwise it's an array with some missing values. + msg = + "Encountered a container with one or more `missing` value(s) on the" * + " left-hand side of an observe statement. To treat the variable on" * + " the left-hand side as a random variable, you should specify a single" * + " `missing` rather than a vector of `missing`s. It is not possible to" * + " set part but not all of a distribution to be `missing`." + if acc.error_on_failure + error(msg) + else + @warn msg + end end # Check for NaN's as well - if _has_nans(left) - error( + if _has_nans(val) + msg = "Encountered a NaN value on the left-hand side of an" * " observe statement; this may indicate that your data" * - " contain NaN values.", - ) + " contain NaN values." + if acc.error_on_failure + error(msg) + else + @warn msg + end end -end - -function record_post_tilde_observe!(context::DebugContext, left, right, varinfo) - stmt = ObserveStmt(; - left=left, right=right, varinfo=context.record_varinfo ? 
varinfo : nothing - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end - -function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi) - record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi) - record_post_tilde_observe!(context, left, right, vi) - return vi -end -function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi) - record_pre_tilde_observe!(context, left, right, vi) - vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi) - record_post_tilde_observe!(context, left, right, vi) - return vi + stmt = ObserveStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -357,26 +328,26 @@ function check_model_pre_evaluation(model::Model) return issuccess end -function check_model_post_evaluation(model::Model) - return check_varnames_seen(model.context.varnames_seen) +function check_model_post_evaluation(acc::DebugAccumulator) + return check_varnames_seen(acc.varnames_seen) end """ - check_model_and_trace([rng, ]model::Model; kwargs...) + check_model_and_trace(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -Check that `model` is valid, warning about any potential issues. +Check that evaluating `model` with the given `varinfo` is valid, warning about any potential +issues. This will check the model for the following issues: + 1. Repeated usage of the same varname in a model. -2. Incorrectly treating a variable as random rather than fixed, and vice versa. +2. `NaN` on the left-hand side of observe statements. # Arguments -- `rng::Random.AbstractRNG`: The random number generator to use when evaluating the model. - `model::Model`: The model to check. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. 
-# Keyword Arguments -- `varinfo::VarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). +# Keyword Argument - `error_on_failure::Bool`: Whether to throw an error if the model check fails. Default: `false`. # Returns @@ -394,7 +365,9 @@ julia> rng = StableRNG(42); julia> @model demo_correct() = x ~ Normal() demo_correct (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_correct()); +julia> model = demo_correct(); varinfo = VarInfo(rng, model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo); julia> issuccess true @@ -402,7 +375,9 @@ true julia> print(trace) assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 -julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); +julia> cond_model = model | (x = 1.0,); + +julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model)); ┌ Warning: The model does not contain any parameters. └ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342 @@ -410,7 +385,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) + observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model @@ -423,58 +398,53 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_failure=true); +julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually + # alert us to the issue of `x` being sampled twice. + model = demo_incorrect(); varinfo = VarInfo(model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true); ERROR: varname x used multiple times in model ``` """ -function check_model_and_trace(model::Model; kwargs...) 
- return check_model_and_trace(Random.default_rng(), model; kwargs...) -end function check_model_and_trace( - rng::Random.AbstractRNG, - model::Model; - varinfo=VarInfo(), - error_on_failure=false, - kwargs..., + model::Model, varinfo::AbstractVarInfo; error_on_failure=false ) - # Execute the model with the debug context. - debug_context = DebugContext( - SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... + # Add debug accumulator to the VarInfo. + # Need a NumProduceAccumulator as well or else get_num_produce may throw + # TODO(mhauru) Remove this once VariableOrderAccumulator stuff is done. + varinfo = DynamicPPL.setaccs!!( + deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator()) ) - debug_model = DynamicPPL.contextualize(model, debug_context) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_model) + issuccess = check_model_pre_evaluation(model) # Force single-threaded execution. - DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo) + DynamicPPL.evaluate_threadunsafe!!(model, varinfo) # Perform checks after evaluating the model. - issuccess &= check_model_post_evaluation(debug_model) + debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) + issuccess = issuccess && check_model_post_evaluation(debug_acc) if !issuccess && error_on_failure error("model check failed") end - trace = debug_context.statements + trace = debug_acc.statements return issuccess, trace end """ - check_model([rng, ]model::Model; kwargs...) - -Check that `model` is valid, warning about any potential issues. + check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -See [`check_model_and_trace`](@ref) for more details on supported keyword arguments -and details of which types of checks are performed. +Check that `model` is valid, warning about any potential issues (or erroring if +`error_on_failure` is `true`). 
# Returns - `issuccess::Bool`: Whether the model check succeeded. """ -check_model(model::Model; kwargs...) = first(check_model_and_trace(model; kwargs...)) -function check_model(rng::Random.AbstractRNG, model::Model; kwargs...) - return first(check_model_and_trace(rng, model; kwargs...)) -end +check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) = + first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure)) # Convenience method used to check if all elements in a list are the same. function all_the_same(xs) @@ -490,7 +460,7 @@ function all_the_same(xs) end """ - has_static_constraints([rng, ]model::Model; num_evals=5, kwargs...) + has_static_constraints([rng, ]model::Model; num_evals=5, error_on_failure=false) Return `true` if the model has static constraints, `false` otherwise. @@ -503,19 +473,16 @@ and checking if the model is consistent across runs. # Keyword Arguments - `num_evals::Int`: The number of evaluations to perform. Default: `5`. -- `kwargs...`: Additional keyword arguments to pass to [`check_model_and_trace`](@ref). +- `error_on_failure::Bool`: Whether to throw an error if any of the `num_evals` model + checks fail. Default: `false`. """ -function has_static_constraints(model::Model; kwargs...) - return has_static_constraints(Random.default_rng(), model; kwargs...) -end function has_static_constraints( - rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs... + rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) + new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) results = map(1:num_evals) do _ - check_model_and_trace(rng, model; kwargs...) + check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end - issuccess = all(first, results) - issuccess || throw(ArgumentError("model check failed")) # Extract the distributions and the corresponding bijectors for each run. 
traces = map(last, results) @@ -527,6 +494,13 @@ function has_static_constraints( # Check if the distributions are the same across all runs. return all_the_same(transforms) end +function has_static_constraints( + model::Model; num_evals::Int=5, error_on_failure::Bool=false +) + return has_static_constraints( + Random.default_rng(), model; num_evals=num_evals, error_on_failure=error_on_failure + ) +end """ gen_evaluator_call_with_types(model[, varinfo]) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 8279ac51a..5bf741ff3 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,13 +1,6 @@ @testset "check_model" begin - @testset "context interface" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext() - DynamicPPL.TestUtils.test_context(context, model) - end - end - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - issuccess, trace = check_model_and_trace(model) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) # These models should all work. 
@test issuccess @@ -33,11 +26,14 @@ return y ~ Normal() end buggy_model = buggy_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "submodel" begin @@ -48,7 +44,10 @@ return x ~ Normal() end model = ModelOuterBroken() - @test_throws ErrorException check_model(model; error_on_failure=true) + varinfo = VarInfo(model) + @test_throws ErrorException check_model( + model, VarInfo(model); error_on_failure=true + ) @model function ModelOuterWorking() # With automatic prefixing => `x` is not duplicated. @@ -57,7 +56,7 @@ return z end model = ModelOuterWorking() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() @@ -66,7 +65,7 @@ return (x1, x2) end model = ModelOuterWorking2() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end @testset "subsumes (x then x[1])" begin @@ -77,11 +76,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes 
(x[1] then x)" begin @@ -92,11 +94,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x.a then x)" begin @@ -107,11 +112,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model(buggy_model; record_varinfo=false) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end end @@ -123,14 +131,14 @@ end end m = demo_nan_in_data([1.0, NaN]) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) # Test NamedTuples with nested arrays, see #898 @model function demo_nan_complicated(nt) nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4]))) return x ~ Normal() end m = demo_nan_complicated((x=1.0, y=[NaN, 0.5])) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) end @testset "incorrect use of condition" begin @@ -139,7 +147,10 @@ return x ~ MvNormal(zeros(length(x)), I) end model = demo_missing_in_multivariate([1.0, missing]) - @test_throws ErrorException check_model(model) + # Have to run this check_model call with an empty varinfo, 
because actually + # instantiating the VarInfo would cause it to throw a MethodError. + model = contextualize(model, SamplingContext()) + @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end @testset "condition both in args and context" begin @@ -153,8 +164,9 @@ OrderedDict(@varname(x[1]) => 2.0), ] conditioned_model = DynamicPPL.condition(model, vals) + varinfo = VarInfo(conditioned_model) @test_throws ErrorException check_model( - conditioned_model; error_on_failure=true + conditioned_model, varinfo; error_on_failure=true ) end end @@ -163,23 +175,26 @@ @testset "printing statements" begin @testset "assume" begin @model demo_assume() = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_assume()) - @test isuccess + model = demo_assume() + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess @test startswith(string(trace), " assume: x ~ Normal") end @testset "observe" begin @model demo_observe(x) = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_observe(1.0)) - @test isuccess - @test occursin(r"observe: \d+\.\d+ ~ Normal", string(trace)) + model = demo_observe(1.0) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess + @test occursin(r"observe: x \(= \d+\.\d+\) ~ Normal", string(trace)) end end @testset "comparing multiple traces" begin + # Run the same model but with different VarInfos. model = DynamicPPL.TestUtils.demo_dynamic_constraint() - issuccess_1, trace_1 = check_model_and_trace(model) - issuccess_2, trace_2 = check_model_and_trace(model) + issuccess_1, trace_1 = check_model_and_trace(model, VarInfo(model)) + issuccess_2, trace_2 = check_model_and_trace(model, VarInfo(model)) @test issuccess_1 && issuccess_2 # Should have the same varnames present. @@ -204,7 +219,7 @@ end for ns in [(2,), (2, 2), (2, 2, 2)] model = demo_undef(ns...) 
- @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end end From f4dd46a34112d9cd75670b59787e64adb689e71b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 18 Jul 2025 16:33:16 +0100 Subject: [PATCH 21/27] VariableOrderAccumulator (#940) * Turn NumProduceAccumulator into VariableOrderAccumulator * Add comparison methods * Make VariableOrderAccumulator use regular Dict * Use copy rather than deepcopy for accumulators * Minor docstring touchup * Remove unnecessary use of NumProduceAccumulator * Fix split(VariableOrderAccumulator) * Remove NumProduceAcc from Debug * Fix set_retained_vns_del! --------- Co-authored-by: Penelope Yong --- docs/src/api.md | 4 +- src/DynamicPPL.jl | 4 +- src/abstract_varinfo.jl | 34 +++++++- src/accumulators.jl | 4 + src/context_implementations.jl | 1 - src/debug_utils.jl | 6 +- src/default_accumulators.jl | 106 +++++++++++++++++------ src/extract_priors.jl | 16 ++-- src/pointwise_logdensities.jl | 4 + src/simple_varinfo.jl | 14 ++- src/threadsafe.jl | 4 +- src/values_as_in_model.jl | 4 + src/varinfo.jl | 151 +++++++++------------------------ test/accumulators.jl | 43 +++++----- test/varinfo.jl | 62 +++++++++++--- 15 files changed, 260 insertions(+), 197 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index e918a095c..180e8dfd4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -341,7 +341,7 @@ get_num_produce set_num_produce!! increment_num_produce!! reset_num_produce!! -setorder! +setorder!! set_retained_vns_del! ``` @@ -368,7 +368,7 @@ DynamicPPL provides the following default accumulators. 
```@docs LogPriorAccumulator LogLikelihoodAccumulator -NumProduceAccumulator +VariableOrderAccumulator ``` ### Common API diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 69e489ce6..c282939a2 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -50,7 +50,7 @@ export AbstractVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, - NumProduceAccumulator, + VariableOrderAccumulator, push!!, empty!!, subset, @@ -73,7 +73,7 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, - setorder!, + setorder!!, istrans, link, link!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 68d3f9c03..581ca829b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo) return vi end +""" + setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer) + +Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe +statements run before sampling `vn`. +""" +function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer) + return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder)) +end + +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. +""" +getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn] + # Variables and their realizations. @doc """ keys(vi::AbstractVarInfo) @@ -980,14 +998,22 @@ end Return the `num_produce` of `vi`. """ -get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num +get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce """ set_num_produce!!(vi::AbstractVarInfo, n::Int) Set the `num_produce` field of `vi` to `n`. 
""" -set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n)) +function set_num_produce!!(vi::AbstractVarInfo, n::Integer) + if hasacc(vi, Val(:VariableOrder)) + acc = getacc(vi, Val(:VariableOrder)) + acc = VariableOrderAccumulator(n, acc.order) + else + acc = VariableOrderAccumulator(n) + end + return setacc!!(vi, acc) +end """ increment_num_produce!!(vi::AbstractVarInfo) @@ -995,14 +1021,14 @@ set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumula Add 1 to `num_produce` in `vi`. """ increment_num_produce!!(vi::AbstractVarInfo) = - map_accumulator!!(increment, vi, Val(:NumProduce)) + map_accumulator!!(increment, vi, Val(:VariableOrder)) """ reset_num_produce!!(vi::AbstractVarInfo) Reset the value of `num_produce` in `vi` to 0. """ -reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce)) +reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi))) """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl index 595c45d3f..1e3e37e61 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` - `accumulate_observe!!(acc::T, right, left, vn)` - `accumulate_assume!!(acc::T, val, logjac, vn, right)` +- `Base.copy(acc::T)` To be able to work with multi-threading, it should also implement: - `split(acc::T)` @@ -138,6 +139,9 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} @inline return haskey(at.nt, accname) end Base.keys(at::AccumulatorTuple) = keys(at.nt) +Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt +Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h) +Base.copy(at::AccumulatorTuple) = AccumulatorTuple(map(copy, at.nt)) function 
Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} return AccumulatorTuple(convert(T, accs.nt)) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b11a723a5..9e9a2d63d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -148,7 +148,6 @@ function assume( f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! vi = BangBang.setindex!!(vi, f(r), vn) - setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. r = vi[vn, dist] diff --git a/src/debug_utils.jl b/src/debug_utils.jl index d1add6e00..d71fa57cc 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -410,11 +410,7 @@ function check_model_and_trace( model::Model, varinfo::AbstractVarInfo; error_on_failure=false ) # Add debug accumulator to the VarInfo. - # Need a NumProduceAccumulator as well or else get_num_produce may throw - # TODO(mhauru) Remove this once VariableOrderAccumulator stuff is done. - varinfo = DynamicPPL.setaccs!!( - deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator()) - ) + varinfo = DynamicPPL.setaccs!!(deepcopy(varinfo), (DebugAccumulator(error_on_failure),)) # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index ab538ba51..418362e8f 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -41,25 +41,40 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T) LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() """ - NumProduceAccumulator{T} <: AbstractAccumulator + VariableOrderAccumulator{T} <: AbstractAccumulator -An accumulator that tracks the number of observations during model execution. +An accumulator that tracks the order of variables in a `VarInfo`. 
+ +This doesn't track the full ordering, but rather how many observations have taken place +before the assume statement for each variable. This is needed for particle methods, where +the model is segmented into parts by each observation, and we need to know which part each +assume statement is in. # Fields $(TYPEDFIELDS) """ -struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator +struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator "the number of observations" - num::T + num_produce::Eltype + "mapping of variable names to their order in the model" + order::Dict{VNType,Eltype} end """ - NumProduceAccumulator{T<:Integer}() + VariableOrderAccumulator{T<:Integer}(n=zero(T)) -Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero. +Create a new `VariableOrderAccumulator` with the number of observations set to `n`. """ -NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T)) -NumProduceAccumulator() = NumProduceAccumulator{Int}() +VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = + VariableOrderAccumulator(convert(T, n), Dict{VarName,T}()) +VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) +VariableOrderAccumulator() = VariableOrderAccumulator{Int}() + +Base.copy(acc::LogPriorAccumulator) = acc +Base.copy(acc::LogLikelihoodAccumulator) = acc +function Base.copy(acc::VariableOrderAccumulator) + return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) +end function Base.show(io::IO, acc::LogPriorAccumulator) return print(io, "LogPriorAccumulator($(repr(acc.logp)))") @@ -67,17 +82,48 @@ end function Base.show(io::IO, acc::LogLikelihoodAccumulator) return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") end -function Base.show(io::IO, acc::NumProduceAccumulator) - return print(io, "NumProduceAccumulator($(repr(acc.num)))") +function Base.show(io::IO, acc::VariableOrderAccumulator) + return print( + io, 
"VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))" + ) +end + +# Note that == and isequal are different, and equality under the latter should imply +# equality of hashes. Both of the below implementations are also different from the default +# implementation for structs. +Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp +function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return acc1.logp == acc2.logp +end +function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order +end + +function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) + return isequal(acc1.logp, acc2.logp) +end +function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) + return isequal(acc1.logp, acc2.logp) +end +function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order) +end + +Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) +function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) + return hash((LogLikelihoodAccumulator, acc.logp), h) +end +function Base.hash(acc::VariableOrderAccumulator, h::UInt) + return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h) end accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood -accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce +accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) -split(acc::NumProduceAccumulator) = acc +split(acc::VariableOrderAccumulator) = copy(acc) function 
combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc.logp + acc2.logp) @@ -85,8 +131,12 @@ end function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc.logp + acc2.logp) end -function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator) - return NumProduceAccumulator(max(acc.num, acc2.num)) +function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + # Note that assumptions are not allowed in parallelised blocks, and thus the + # dictionaries should be identical. + return VariableOrderAccumulator( + max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order) + ) end function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) @@ -95,11 +145,12 @@ end function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc1.logp + acc2.logp) end -increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num)) +function increment(acc::VariableOrderAccumulator) + return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order) +end Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) -Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num)) function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) return acc + LogPriorAccumulator(logpdf(right, val) + logjac) @@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) end -accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc -accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc) +function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right) + 
acc.order[vn] = acc.num_produce + return acc +end +accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc) function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) @@ -126,15 +180,19 @@ function Base.convert( return LogLikelihoodAccumulator(convert(T, acc.logp)) end function Base.convert( - ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator -) where {T} - return NumProduceAccumulator(convert(T, acc.num)) + ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator +) where {ElType,VnType} + order = Dict{VnType,ElType}() + for (k, v) in acc.order + order[convert(VnType, k)] = convert(ElType, v) + end + return VariableOrderAccumulator(convert(ElType, acc.num_produce), order) end # TODO(mhauru) -# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on +# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to -# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is +# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. 
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) @@ -149,6 +207,6 @@ function default_accumulators( return AccumulatorTuple( LogPriorAccumulator{FloatT}(), LogLikelihoodAccumulator{FloatT}(), - NumProduceAccumulator{IntT}(), + VariableOrderAccumulator{IntT}(), ) end diff --git a/src/extract_priors.jl b/src/extract_priors.jl index bd6bdb2f2..64dcf2eea 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -4,6 +4,10 @@ end PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}()) +function Base.copy(acc::PriorDistributionAccumulator) + return PriorDistributionAccumulator(copy(acc.priors)) +end + accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) @@ -112,10 +116,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() - # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a - # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you - # can't push new variables without knowing the num_produce. Remove this when possible. - varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) + varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end @@ -129,12 +130,7 @@ This is done by evaluating the model at the values present in `varinfo` and recording the distributions that are present at each tilde statement. 
""" function extract_priors(model::Model, varinfo::AbstractVarInfo) - # TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a - # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you - # can't push new variables without knowing the num_produce. Remove this when possible. - varinfo = setaccs!!( - deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator()) - ) + varinfo = setaccs!!(deepcopy(varinfo), (PriorDistributionAccumulator(),)) varinfo = last(evaluate!!(model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 59cc5e1bb..44882f91e 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -32,6 +32,10 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end +function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) +end + function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) logps = acc.logps # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. 
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ddc3275ae..d8a00d5ea 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -122,7 +122,7 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) @@ -130,7 +130,7 @@ julia> # (✓) Positive probability mass on negative numbers! julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) @@ -195,6 +195,12 @@ struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformati transformation::C end +function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) + return vi1.values == vi2.values && + vi1.accs == vi2.accs && + vi1.transformation == vi2.transformation +end + transformation(vi::SimpleVarInfo) = vi.transformation function SimpleVarInfo(values, accs) @@ -249,7 +255,7 @@ end # Constructor from `VarInfo`. 
function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} values = values_as(vi, D) - return SimpleVarInfo(values, deepcopy(getaccs(vi))) + return SimpleVarInfo(values, copy(getaccs(vi))) end function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} values = values_as(vi, D) @@ -448,7 +454,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - accs = deepcopy(getaccs(varinfo_right)) + accs = copy(getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 51c57651d..9b82cd8b4 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -87,7 +87,9 @@ end syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) +function setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) + return ThreadSafeVarInfo(setorder!!(vi.varinfo, vn, index), vi.accs_by_thread) +end setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4922ddbb0..1fa0555f0 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -20,6 +20,10 @@ function ValuesAsInModelAccumulator(include_colon_eq) return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq) end +function Base.copy(acc::ValuesAsInModelAccumulator) + return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq) +end + accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel function split(acc::ValuesAsInModelAccumulator) diff --git a/src/varinfo.jl b/src/varinfo.jl index b3380e7f9..f280eb98b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -15,10 +15,9 @@ not. 
Let `md` be an instance of `Metadata`: - `md.vns` is the vector of all `VarName` instances. - `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`. + `md.vns`, `md.ranges` `md.dists`, and `md.flags`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. - `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the @@ -57,13 +56,21 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Number of `observe` statements before each random variable is sampled - orders::Vector{Int} - # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` flags::Dict{String,BitVector} end +function Base.:(==)(md1::Metadata, md2::Metadata) + return ( + md1.idcs == md2.idcs && + md1.vns == md2.vns && + md1.ranges == md2.ranges && + md1.vals == md2.vals && + md1.dists == md2.dists && + md1.flags == md2.flags + ) +end + ########### # VarInfo # ########### @@ -143,6 +150,10 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +function Base.:(==)(vi1::VarInfo, vi2::VarInfo) + return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) +end + # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. 
@@ -227,8 +238,6 @@ function typed_varinfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) # New flags sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) @@ -246,13 +255,11 @@ function typed_varinfo(vi::UntypedVarInfo) push!( new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), + Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags), ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, deepcopy(vi.accs)) + return VarInfo(nt, copy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -295,7 +302,7 @@ Return a VarInfo object for the given `model`, which has just a single """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - return VarInfo(md, deepcopy(vi.accs)) + return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() @@ -319,12 +326,12 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - return VarInfo(md, deepcopy(vi.accs)) + return VarInfo(md, copy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) nt = NamedTuple(new_metas) - return VarInfo(nt, deepcopy(vi.accs)) + return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() @@ -358,8 +365,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # convert to into an intermediate variable makes this unstable (constant propagation) # fails. Take care when editing. 
accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), - deepcopy(getaccs(vi)), + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) ) return VarInfo(md, accs) end @@ -382,7 +388,7 @@ end end function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.flags) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -408,7 +414,6 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - Vector{Int}(), flags, ) end @@ -426,7 +431,6 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - empty!(meta.orders) for k in keys(meta.flags) empty!(meta.flags[k]) end @@ -443,7 +447,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.accs)) + return VarInfo(metadata, copy(varinfo.accs)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -515,15 +519,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va end flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) - return Metadata( - indices, - vns, - ranges, - vals, - metadata.dists[indices_for_vns], - metadata.orders[indices_for_vns], - flags, - ) + return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) @@ -532,7 +528,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, deepcopy(varinfo_right.accs)) + return VarInfo(metadata, copy(varinfo_right.accs)) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -593,7 +589,6 @@ 
function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) @@ -615,13 +610,12 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - push!(orders, getorder(metadata_for_vn, vn)) for k in keys(flags) push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end - return Metadata(idcs, vns, ranges, vals, dists, orders, flags) + return Metadata(idcs, vns, ranges, vals, dists, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -1288,7 +1282,6 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.orders, metadata.flags, ), cumulative_logjac @@ -1454,7 +1447,6 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.orders, metadata.flags, ), cumulative_logjac @@ -1636,7 +1628,6 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) ("VarNames", vi.metadata.vns), ("Range", vi.metadata.ranges), ("Vals", vi.metadata.vals), - ("Orders", vi.metadata.orders), ] for accname in acckeys(vi) push!(lines, (string(accname), getacc(vi, Val(accname)))) @@ -1721,13 +1712,12 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) [1:length(val)], val, [dist], - [get_num_produce(vi)], Dict{String,BitVector}("trans" => [false], "del" => [false]), ) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, get_num_produce(vi)) + push!(meta, vn, r, dist) end return vi @@ -1747,7 +1737,7 @@ end # exist in the NTVarInfo already. 
We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. -function Base.push!(meta::Metadata, vn, r, dist, num_produce) +function Base.push!(meta::Metadata, vn, r, dist) val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) @@ -1756,7 +1746,6 @@ function Base.push!(meta::Metadata, vn, r, dist, num_produce) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) return meta @@ -1767,31 +1756,6 @@ function Base.delete!(vi::VarInfo, vn::VarName) return vi end -""" - setorder!(vi::VarInfo, vn::VarName, index::Int) - -Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe -statements run before sampling `vn`. -""" -function setorder!(vi::VarInfo, vn::VarName, index::Int) - setorder!(getmetadata(vi, vn), vn, index) - return vi -end -function setorder!(metadata::Metadata, vn::VarName, index::Int) - metadata.orders[metadata.idcs[vn]] = index - return metadata -end -setorder!(vnv::VarNamedVector, ::VarName, ::Int) = vnv - -""" - getorder(vi::VarInfo, vn::VarName) - -Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements -run before sampling `vn`. -""" -getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) -getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] - ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1849,55 +1813,24 @@ end """ set_retained_vns_del!(vi::VarInfo) -Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. +Set the `"del"` flag of variables in `vi` with `order > num_produce` to `true`. + +Will error if `vi` does not have an accumulator for `VariableOrder`. 
""" -function set_retained_vns_del!(vi::UntypedVarInfo) - idcs = _getidcs(vi) - if get_num_produce(vi) == 0 - for i in length(idcs):-1:1 - vi.metadata.flags["del"][idcs[i]] = true - end - else - for i in 1:length(vi.orders) - if i in idcs && vi.orders[i] > get_num_produce(vi) - vi.metadata.flags["del"][i] = true - end +function set_retained_vns_del!(vi::VarInfo) + if !hasacc(vi, Val(:VariableOrder)) + msg = "`vi` must have an accumulator for VariableOrder to set the `del` flag." + raise(ArgumentError(msg)) + end + num_produce = get_num_produce(vi) + for vn in keys(vi) + order = getorder(vi, vn) + if order > num_produce + set_flag!(vi, vn, "del") end end return nothing end -function set_retained_vns_del!(vi::NTVarInfo) - idcs = _getidcs(vi) - return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) -end -@generated function _set_retained_vns_del!( - metadata, idcs::NamedTuple{names}, num_produce -) where {names} - expr = Expr(:block) - for f in names - f_idcs = :(idcs.$f) - f_orders = :(metadata.$f.orders) - f_flags = :(metadata.$f.flags) - push!( - expr.args, - quote - # Set the flag for variables with symbol `f` - if num_produce == 0 - for i in length($f_idcs):-1:1 - $f_flags["del"][$f_idcs[i]] = true - end - else - for i in 1:length($f_orders) - if i in $f_idcs && $f_orders[i] > num_produce - $f_flags["del"][i] = true - end - end - end - end, - ) - end - return expr -end # TODO: Maybe rename or something? 
""" diff --git a/test/accumulators.jl b/test/accumulators.jl index 36bb95e46..5963ad8b5 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -7,7 +7,7 @@ using DynamicPPL: AccumulatorTuple, LogLikelihoodAccumulator, LogPriorAccumulator, - NumProduceAccumulator, + VariableOrderAccumulator, accumulate_assume!!, accumulate_observe!!, combine, @@ -31,11 +31,11 @@ using DynamicPPL: LogLikelihoodAccumulator{Float64}() == LogLikelihoodAccumulator{Float64}(0.0) == zero(LogLikelihoodAccumulator(1.0)) - @test NumProduceAccumulator(0) == - NumProduceAccumulator() == - NumProduceAccumulator{Int}() == - NumProduceAccumulator{Int}(0) == - zero(NumProduceAccumulator(1)) + @test VariableOrderAccumulator(0) == + VariableOrderAccumulator() == + VariableOrderAccumulator{Int}() == + VariableOrderAccumulator{Int}(0) == + VariableOrderAccumulator(0, Dict{VarName,Int}()) end @testset "addition and incrementation" begin @@ -47,19 +47,19 @@ using DynamicPPL: LogLikelihoodAccumulator(2.0f0) @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == LogLikelihoodAccumulator(2.0) - @test increment(NumProduceAccumulator()) == NumProduceAccumulator(1) - @test increment(NumProduceAccumulator{UInt8}()) == - NumProduceAccumulator{UInt8}(1) + @test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1) + @test increment(VariableOrderAccumulator{UInt8}()) == + VariableOrderAccumulator{UInt8}(1) end @testset "split and combine" begin for acc in [ LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0), - NumProduceAccumulator(1), + VariableOrderAccumulator(1), LogPriorAccumulator(1.0f0), LogLikelihoodAccumulator(1.0f0), - NumProduceAccumulator(UInt8(1)), + VariableOrderAccumulator(UInt8(1)), ] @test combine(acc, split(acc)) == acc end @@ -71,8 +71,9 @@ using DynamicPPL: @test convert( LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0) ) == LogLikelihoodAccumulator{Float32}(1.0f0) - @test convert(NumProduceAccumulator{UInt8}, 
NumProduceAccumulator(1)) == - NumProduceAccumulator{UInt8}(1) + @test convert( + VariableOrderAccumulator{UInt8,VarName}, VariableOrderAccumulator(1) + ) == VariableOrderAccumulator{UInt8}(1) @test convert_eltype(Float32, LogPriorAccumulator(1.0)) == LogPriorAccumulator{Float32}(1.0f0) @@ -90,8 +91,8 @@ using DynamicPPL: @test accumulate_assume!!( LogLikelihoodAccumulator(1.0), val, logjac, vn, dist ) == LogLikelihoodAccumulator(1.0) - @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) == - NumProduceAccumulator(1) + @test accumulate_assume!!(VariableOrderAccumulator(1), val, logjac, vn, dist) == + VariableOrderAccumulator(1, Dict{VarName,Int}((vn => 1))) end @testset "accumulate_observe" begin @@ -102,8 +103,8 @@ using DynamicPPL: LogPriorAccumulator(1.0) @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == LogLikelihoodAccumulator(1.0 + logpdf(right, left)) - @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) == - NumProduceAccumulator(2) + @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == + VariableOrderAccumulator(2) end end @@ -113,7 +114,7 @@ using DynamicPPL: lp_f32 = LogPriorAccumulator(1.0f0) ll_f64 = LogLikelihoodAccumulator(1.0) ll_f32 = LogLikelihoodAccumulator(1.0f0) - np_i64 = NumProduceAccumulator(1) + np_i64 = VariableOrderAccumulator(1) @testset "constructors" begin @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) @@ -131,12 +132,12 @@ using DynamicPPL: @test at_all64[:LogPrior] == lp_f64 @test at_all64[:LogLikelihood] == ll_f64 - @test at_all64[:NumProduce] == np_i64 + @test at_all64[:VariableOrder] == np_i64 - @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce)) + @test haskey(AccumulatorTuple(np_i64), Val(:VariableOrder)) @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 - @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce) + @test 
keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder) @test collect(at_all64) == [lp_f64, ll_f64, np_i64] # Replace the existing LogPriorAccumulator diff --git a/test/varinfo.jl b/test/varinfo.jl index 75868eb66..dad54f024 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -21,12 +21,13 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) push!!(vi, vn, r, dist) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") r = rand(dist) vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) r else vi[vn] @@ -54,7 +55,6 @@ end ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] - @test meta.orders[ind] == fmeta.orders[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -208,7 +208,7 @@ end @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) @test_throws r"has no field `?LogLikelihood" getlogp(vi) @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) - @test_throws r"has no field `?NumProduce" get_num_produce(vi) + @test_throws r"has no field `?VariableOrder" get_num_produce(vi) @test begin vi = acclogprior!!(vi, 1.0) getlogprior(vi) == lp_a + lp_b + 1.0 @@ -220,7 +220,7 @@ end vi = last( DynamicPPL.evaluate!!( - m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),)) + m, DynamicPPL.setaccs!!(deepcopy(vi), (VariableOrderAccumulator(),)) ), ) # need regex because 1.11 and 1.12 throw different errors (in 1.12 the @@ -239,8 +239,8 @@ end @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) @test_throws r"has no field `?LogPrior" getlogp(vi) @test_throws r"has no field `?LogPrior" getlogjoint(vi) - @test_throws r"has no field `?NumProduce" get_num_produce(vi) - @test_throws r"has no field 
`?NumProduce" reset_num_produce!!(vi) + @test_throws r"has no field `?VariableOrder" get_num_produce(vi) + @test_throws r"has no field `?VariableOrder" reset_num_produce!!(vi) end @testset "flags" begin @@ -1091,13 +1091,36 @@ end randr(vi, vn_a2, dists[2]) vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_b) == 2 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_a2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 @test DynamicPPL.get_num_produce(vi) == 3 + @test !DynamicPPL.is_flagged(vi, vn_z1, "del") + @test !DynamicPPL.is_flagged(vi, vn_a1, "del") + @test !DynamicPPL.is_flagged(vi, vn_b, "del") + @test !DynamicPPL.is_flagged(vi, vn_z2, "del") + @test !DynamicPPL.is_flagged(vi, vn_a2, "del") + @test !DynamicPPL.is_flagged(vi, vn_z3, "del") + + vi = DynamicPPL.reset_num_produce!!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) + DynamicPPL.set_retained_vns_del!(vi) + @test !DynamicPPL.is_flagged(vi, vn_z1, "del") + @test !DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_b, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_b, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") @@ -1110,7 +1133,12 @@ end vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, 
vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_b) == 2 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a2) == 3 @test DynamicPPL.get_num_produce(vi) == 3 vi = empty!!(DynamicPPL.typed_varinfo(vi)) @@ -1125,9 +1153,12 @@ end randr(vi, vn_a2, dists[2]) vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test vi.metadata.b.orders == [2] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_a2) == 2 + @test DynamicPPL.getorder(vi, vn_b) == 2 @test DynamicPPL.get_num_produce(vi) == 3 vi = DynamicPPL.reset_num_produce!!(vi) @@ -1146,9 +1177,12 @@ end vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_a2) == 3 + @test DynamicPPL.getorder(vi, vn_b) == 2 @test DynamicPPL.get_num_produce(vi) == 3 end From e60eab0d441f175115eb66c08ca151d15e4c9118 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 18 Jul 2025 20:53:12 +0100 Subject: [PATCH 22/27] Accumulators stage 2 (#925) * Give LogDensityFunction the getlogdensity field * Allow missing LogPriorAccumulator when linking * Trim whitespace * Run formatter * Fix a few typos * Fix comma -> semicolon * Fix `LogDensityAt` invocation * Fix one last test * Fix tests --------- Co-authored-by: Penelope Yong --- benchmarks/src/DynamicPPLBenchmarks.jl | 2 +- src/logdensityfunction.jl | 141 ++++++++++++++++--------- src/simple_varinfo.jl 
| 8 +- src/test_utils/ad.jl | 13 ++- src/varinfo.jl | 20 +++- src/varname.jl | 2 +- test/ad.jl | 16 +-- test/logdensityfunction.jl | 12 ++- 8 files changed, 144 insertions(+), 70 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 26ec35b65..54a302a6f 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -86,7 +86,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e7565d137..3c092c06b 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,7 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model); + getlogdensity::Function=getlogjoint, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -28,9 +29,10 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with the -type of varinfo to be used. These must be known in order to calculate the log -density (using [`DynamicPPL.evaluate!!`](@ref)). +At its most basic level, a LogDensityFunction wraps the model together with a +function that specifies how to extract the log density, and the type of +VarInfo to be used. These must be known in order to calculate the log density +(using [`DynamicPPL.evaluate!!`](@ref)). 
If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -72,13 +74,13 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # LogDensityFunction respects the accumulators in VarInfo: - f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),))); +julia> # One can also specify evaluating e.g. the log prior only: + f_prior = LogDensityFunction(model, getlogprior); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -93,11 +95,13 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} } <: AbstractModel "model used for evaluation" model::M - "varinfo used for evaluation" + "function to be called on `varinfo` to extract the log density. By default `getlogjoint`." + getlogdensity::F + "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V "AD type used for evaluation of log density gradient. 
If `nothing`, no gradient can be calculated" adtype::AD @@ -106,7 +110,8 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model); + getlogdensity::Function=getlogjoint, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing @@ -120,15 +125,22 @@ struct LogDensityFunction{ # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x) + prep = DI.prepare_gradient( + LogDensityAt(model, getlogdensity, varinfo), adtype, x + ) else prep = DI.prepare_gradient( - logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo) + logdensity_at, + adtype, + x, + DI.Constant(model), + DI.Constant(getlogdensity), + DI.Constant(varinfo), ) end end - return new{typeof(model),typeof(varinfo),typeof(adtype)}( - model, varinfo, adtype, prep + return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( + model, getlogdensity, varinfo, adtype, prep ) end end @@ -149,83 +161,112 @@ function LogDensityFunction( return if adtype === f.adtype f # Avoid recomputing prep if not needed else - LogDensityFunction(f.model, f.varinfo; adtype=adtype) + LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) end end +""" + ldf_default_varinfo(model::Model, getlogdensity::Function) + +Create the default AbstractVarInfo that should be used for evaluating the log density. + +Only the accumulators necesessary for `getlogdensity` will be used. +""" +function ldf_default_varinfo(::Model, getlogdensity::Function) + msg = """ + LogDensityFunction does not know what sort of VarInfo should be used when \ + `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. 
+ """ + return error(msg) +end + +ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) +end + +function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) + return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +end + """ logdensity_at( x::AbstractVector, model::Model, + getlogdensity::Function, varinfo::AbstractVarInfo, ) -Evaluate the log density of the given `model` at the given parameter values `x`, -using the given `varinfo`. Note that the `varinfo` argument is provided only -for its structure, in the sense that the parameters from the vector `x` are -inserted into it, and its own parameters are discarded. It does, however, -determine whether the log prior, likelihood, or joint is returned, based on -which accumulators are set in it. +Evaluate the log density of the given `model` at the given parameter values +`x`, using the given `varinfo`. Note that the `varinfo` argument is provided +only for its structure, in the sense that the parameters from the vector `x` +are inserted into it, and its own parameters are discarded. `getlogdensity` is +the function that extracts the log density from the evaluated varinfo. 
""" -function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo) +function logdensity_at( + x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo +) varinfo_new = unflatten(varinfo, x) varinfo_eval = last(evaluate!!(model, varinfo_new)) - has_prior = hasacc(varinfo_eval, Val(:LogPrior)) - has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood)) - if has_prior && has_likelihood - return getlogjoint(varinfo_eval) - elseif has_prior - return getlogprior(varinfo_eval) - elseif has_likelihood - return getloglikelihood(varinfo_eval) - else - error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood") - end + return getlogdensity(varinfo_eval) end """ - LogDensityAt{M<:Model,V<:AbstractVarInfo}( + LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( model::M + getlogdensity::F, varinfo::V ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -varinfo)`. +getlogdensity, varinfo)`. """ -struct LogDensityAt{M<:Model,V<:AbstractVarInfo} +struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} model::M + getlogdensity::F varinfo::V end -(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo) +function (ld::LogDensityAt)(x::AbstractVector) + return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) +end ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,Nothing}} -) where {M,V} + ::Type{<:LogDensityFunction{M,F,V,Nothing}} +) where {M,F,V} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,AD}} -) where {M,V,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,F,V,AD}} +) where {M,F,V,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo) + return 
logdensity_at(x, f.model, f.getlogdensity, f.varinfo) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,AD}, x::AbstractVector -) where {M,V,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,F,V,AD}, x::AbstractVector +) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x) + DI.value_and_gradient( + LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x + ) else DI.value_and_gradient( - logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo) + logdensity_at, + f.prep, + f.adtype, + x, + DI.Constant(f.model), + DI.Constant(f.getlogdensity), + DI.Constant(f.varinfo), ) end end @@ -264,9 +305,9 @@ There are two ways of dealing with this: 1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) -2. Use a constant context. This lets us pass a two-argument function to - DifferentiationInterface, as long as we also give it the 'inactive argument' - (i.e. the model) wrapped in `DI.Constant`. +2. Use a constant DI.Context. This lets us pass a two-argument function to DI, + as long as we also give it the 'inactive argument' (i.e. the model) wrapped + in `DI.Constant`. The relative performance of the two approaches, however, depends on the AD backend used. Some benchmarks are provided here: @@ -292,7 +333,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. 
""" function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo; adtype=f.adtype) + return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d8a00d5ea..abb93a0ab 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -619,7 +619,9 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) vi_new = Accessors.@set(vi.values = y) - vi_new = acclogprior!!(vi_new, -logjac) + if hasacc(vi_new, Val(:LogPrior)) + vi_new = acclogprior!!(vi_new, -logjac) + end return settrans!!(vi_new, t) end @@ -632,7 +634,9 @@ function invlink!!( y = vi.values x, logjac = with_logabsdet_jacobian(b, y) vi_new = Accessors.@set(vi.values = x) - vi_new = acclogprior!!(vi_new, logjac) + if hasacc(vi_new, Val(:LogPrior)) + vi_new = acclogprior!!(vi_new, logjac) + end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 155f3b68d..d4f6f9a1d 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link +using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -88,6 +88,8 @@ $(TYPEDFIELDS) struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} "The DynamicPPL model that was tested" model::Model + "The function used to extract the log density from the model" + getlogdensity::Function "The VarInfo that was used" varinfo::AbstractVarInfo "The values at which the model was evaluated" @@ -222,6 +224,7 @@ function run_ad( benchmark::Bool=false, 
atol::AbstractFloat=100 * eps(), rtol::AbstractFloat=sqrt(eps()), + getlogdensity::Function=getlogjoint, rng::AbstractRNG=default_rng(), varinfo::AbstractVarInfo=link(VarInfo(rng, model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, @@ -241,7 +244,8 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo; adtype=adtype) + ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) + value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad = collect(grad) @@ -257,7 +261,9 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype) + ldf_reference = LogDensityFunction( + model, getlogdensity, varinfo; adtype=test.adtype + ) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) @@ -282,6 +288,7 @@ function run_ad( return ADResult( model, + getlogdensity, varinfo, params, adtype, diff --git a/src/varinfo.jl b/src/varinfo.jl index f280eb98b..d8233ae07 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1148,7 +1148,9 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
setval!(md, yvec, vn) - vi = acclogprior!!(vi, -logjac) + if hasacc(vi, Val(:LogPrior)) + vi = acclogprior!!(vi, -logjac) + end return vi end @@ -1185,7 +1187,9 @@ function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1199,7 +1203,9 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1347,7 +1353,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end @@ -1361,7 +1369,9 @@ function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogPrior)) + new_varinfo = acclogprior!!(new_varinfo, -logjac) + end return new_varinfo end diff --git a/src/varname.jl b/src/varname.jl index c16587065..3eb1f2460 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -7,7 +7,7 @@ This is a very restricted version `subumes(u::VarName, v::VarName)` only really - Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 
2][3]`, etc. ## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` +- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, and similarly to `v`. But this is slow. """ diff --git a/test/ad.jl b/test/ad.jl index 48dffeadb..308894ada 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -30,7 +30,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, linked_varinfo) + f = LogDensityFunction(m, getlogjoint, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff @@ -52,17 +52,17 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, linked_varinfo; adtype=adtype + m, getlogjoint, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, linked_varinfo; adtype=adtype + m, getlogjoint, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, linked_varinfo; adtype=adtype + m, getlogjoint, linked_varinfo; adtype=adtype ) else @test run_ad( @@ -111,10 +111,12 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true)) - @test 
LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any + ldf = LogDensityFunction( + sampling_model, getlogjoint; adtype=AutoReverseDiff(; compile=true) + ) + x = ldf.varinfo[:] + @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any end # Test that various different ways of specifying array types as arguments work with all diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index d6e66ec59..c4d0d6beb 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -15,8 +15,18 @@ end vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + vi = first(varinfos) + theta = vi[:] + ldf_joint = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) + ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) + @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) + ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) + @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ + loglikelihood(model, vi) + @testset "$(varinfo)" for varinfo in varinfos - logdensity = DynamicPPL.LogDensityFunction(model, varinfo) + logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) From bd99d4f42e4e98d320ae98f54c99e52eb5dc1afd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 24 Jul 2025 15:05:33 +0100 Subject: [PATCH 23/27] Implement more consistent tracking of logp components via `LogJacobianAccumulator` (#998) * logjac accumulator * Fix tests * Fix a whole bunch of stuff * Fix final tests * Fix docs * Fix docs/doctests * Fix maths in LogJacobianAccumulator docstring * Twiddle with a comment * Add changelog * Fix accumulator docstring * logJ -> logjac * Fix logjac accumulation for 
StaticTransformation --- HISTORY.md | 32 ++++++-- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- docs/make.jl | 4 +- docs/src/api.md | 6 ++ src/DynamicPPL.jl | 6 ++ src/abstract_varinfo.jl | 109 +++++++++++++++++++++---- src/accumulators.jl | 15 +++- src/context_implementations.jl | 6 +- src/default_accumulators.jl | 85 ++++++++++++++++++- src/logdensityfunction.jl | 57 ++++++++++--- src/model.jl | 8 ++ src/pointwise_logdensities.jl | 3 + src/simple_varinfo.jl | 24 +++--- src/test_utils/ad.jl | 5 +- src/threadsafe.jl | 5 ++ src/transforming.jl | 30 ++----- src/varinfo.jl | 54 ++++++------ test/accumulators.jl | 6 +- test/ad.jl | 10 +-- test/linking.jl | 20 +++-- test/logdensityfunction.jl | 3 + test/model.jl | 7 ++ test/simple_varinfo.jl | 73 +++++++++-------- test/varinfo.jl | 77 ++++++++--------- 24 files changed, 458 insertions(+), 191 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d367e9ad7..b59d8dd7f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -32,20 +32,40 @@ Their semantics are the same as in Julia's `isapprox`; two values are equal if t You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`). -### Removal of `PriorContext` and `LikelihoodContext` - -A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. -Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. +### Evaluating model log-probabilities in more detail Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field. 
`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively. -Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below). +In this version, we have overhauled this quite substantially. +The technical details of exactly _how_ this is done are covered in the 'Accumulators' section below, but the upshot is that the log prior, log likelihood, and log Jacobian terms (for any linked variables) are separately tracked. + +Specifically, you will want to use the following functions to access these log probabilities: + + - `getlogprior(varinfo)` to get the log prior. **Note:** This version introduces new, more consistent behaviour for this function, in that it always returns the log-prior of the values in the original, untransformed space, even if the `varinfo` has been linked. + - `getloglikelihood(varinfo)` to get the log likelihood. + - `getlogjoint(varinfo)` to get the log joint probability. **Note:** Similar to `getlogprior`, this function now always returns the log joint of the values in the original, untransformed space, even if the `varinfo` has been linked. + +If you are using linked VarInfos (e.g. if you are writing a sampler), you may find that you need to obtain the log probability of the variables in the transformed space. +To this end, you can use: + + - `getlogjac(varinfo)` to get the log Jacobian of the link transforms for any linked variables. + - `getlogprior_internal(varinfo)` to get the log prior of the variables in the transformed space. + - `getlogjoint_internal(varinfo)` to get the log joint probability of the variables in the transformed space. + +Since transformations only apply to random variables, the likelihood is unaffected by linking. 
+ +### Removal of `PriorContext` and `LikelihoodContext` + +Following on from the above, a number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. +Although these are not the only _exported_ contexts, we consider it unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. + If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood). If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`. `LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want. -Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +Thus, if you pass `getlogprior_internal` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +(You should consider whether your use case needs the log prior in the transformed space, or the original space, and use (respectively) `getlogprior_internal` or `getlogprior` as needed.) The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior. Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`. 
diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 54a302a6f..8c5032ace 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..9c59cb06b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,7 +21,9 @@ makedocs(; sitename="DynamicPPL", # The API index.html page is fairly large, and violates the default HTML page size # threshold of 200KiB, so we double that. - format=Documenter.HTML(; size_threshold=2^10 * 400), + format=Documenter.HTML(; + size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() + ), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] diff --git a/docs/src/api.md b/docs/src/api.md index 180e8dfd4..9237943c7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators. ```@docs LogPriorAccumulator +LogJacobianAccumulator LogLikelihoodAccumulator VariableOrderAccumulator ``` @@ -380,7 +381,12 @@ getlogp setlogp!! acclogp!! getlogjoint +getlogjoint_internal +getlogjac +setlogjac!! +acclogjac!! getlogprior +getlogprior_internal setlogprior!! acclogprior!! 
getloglikelihood diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c282939a2..15d39014e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -50,6 +50,7 @@ export AbstractVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, + LogJacobianAccumulator, VariableOrderAccumulator, push!!, empty!!, @@ -58,10 +59,15 @@ export AbstractVarInfo, getlogjoint, getlogprior, getloglikelihood, + getlogjac, + getlogjoint_internal, + getlogprior_internal, setlogp!!, setlogprior!!, + setlogjac!!, setloglikelihood!!, acclogp!!, + acclogjac!!, acclogprior!!, accloglikelihood!!, resetlogp!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 581ca829b..cf5ce5706 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). """ getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) +""" + getlogjoint_internal(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters as +they are stored internally in `vi`, including the log-Jacobian for any linked +parameters. + +In general, we have that: + +```julia +getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi) +``` +""" +getlogjoint_internal(vi::AbstractVarInfo) = + getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi) + """ getlogp(vi::AbstractVarInfo) -Return a NamedTuple of the log prior and log likelihood probabilities. +Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities. -The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an -error will be thrown. +The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them +are not present in `vi` an error will be thrown. 
""" function getlogp(vi::AbstractVarInfo) - return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) + return (; + logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi) + ) end """ @@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ """ getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp +""" + getlogprior_internal(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters as stored internally +in `vi`. This includes the log-Jacobian for any linked parameters. + +In general, we have that: + +```julia +getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi) +``` +""" +getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi) + +""" + getlogjac(vi::AbstractVarInfo) + +Return the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`setlogjac!!`](@ref). +""" +getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logjac + """ getloglikelihood(vi::AbstractVarInfo) @@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re """ setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) +""" + setlogjac!!(vi::AbstractVarInfo, logjac) + +Set the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref). +""" +setlogjac!!(vi::AbstractVarInfo, logjac) = setacc!!(vi, LogJacobianAccumulator(logjac)) + """ setloglikelihood!!(vi::AbstractVarInfo, logp) @@ -215,10 +267,13 @@ Set both the log prior and the log likelihood probabilities in `vi`. See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). 
""" function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} - if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) - error("logp must have the fields logprior and loglikelihood and no other fields.") + if Set(names) != Set([:logprior, :logjac, :loglikelihood]) + error( + "The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.", + ) end vi = setlogprior!!(vi, logp.logprior) + vi = setlogjac!!(vi, logp.logjac) vi = setloglikelihood!!(vi, logp.loglikelihood) return vi end @@ -226,7 +281,7 @@ end function setlogp!!(vi::AbstractVarInfo, logp::Number) return error(""" `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use - `setloglikelihood!!` and/or `setlogprior!!` instead. + `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. """) end @@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp) return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) end +""" + acclogjac!!(vi::AbstractVarInfo, logjac) + +Add `logjac` to the value of the log Jacobian in `vi`. + +See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). 
+""" +function acclogjac!!(vi::AbstractVarInfo, logjac) + return map_accumulator!!( + acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian) + ) +end + """ accloglikelihood!!(vi::AbstractVarInfo, logp) @@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo) if hasacc(vi, Val(:LogPrior)) vi = map_accumulator!!(zero, vi, Val(:LogPrior)) end + if hasacc(vi, Val(:LogJacobian)) + vi = map_accumulator!!(zero, vi, Val(:LogJacobian)) + end if hasacc(vi, Val(:LogLikelihood)) vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) end @@ -836,9 +907,12 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogprior(vi) - logjac - vi_new = setlogprior!!(unflatten(vi, y), lp_new) - return settrans!!(vi_new, t) + # Set parameters and add the logjac term. + vi = unflatten(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return settrans!!(vi, t) end function invlink!!( @@ -846,11 +920,16 @@ function invlink!!( ) b = t.bijector y = vi[:] - x, logjac = with_logabsdet_jacobian(b, y) - - lp_new = getlogprior(vi) + logjac - vi_new = setlogprior!!(unflatten(vi, x), lp_new) - return settrans!!(vi_new, NoTransformation()) + x, inv_logjac = with_logabsdet_jacobian(b, y) + + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + vi = unflatten(vi, x) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, inv_logjac) + end + return settrans!!(vi, NoTransformation()) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 1e3e37e61..0dcf9c7cf 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -11,10 +11,21 @@ seen so far. 
An accumulator type `T <: AbstractAccumulator` must implement the following methods: - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` -- `accumulate_observe!!(acc::T, right, left, vn)` -- `accumulate_assume!!(acc::T, val, logjac, vn, right)` +- `accumulate_observe!!(acc::T, dist, val, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, dist)` - `Base.copy(acc::T)` +In these functions: +- `val` is the new value of the random variable sampled from a distribution (always in + the original unlinked space), or the value on the left-hand side of an observe + statement. +- `dist` is the distribution on the RHS of the tilde statement. +- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the + tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. +- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the + variable is stored as a linked value in the VarInfo. If the variable is stored in its + original, unlinked form, then `logjac` is zero. + To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9e9a2d63d..786d7c913 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -123,8 +123,8 @@ end function assume(dist::Distribution, vn::VarName, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, logjac, vn, dist) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) return x, vi end @@ -166,6 +166,6 @@ function assume( # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? 
link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + vi = accumulate_assume!!(vi, r, logjac, vn, dist) return r, vi end diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 418362e8f..d503b3e64 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -3,6 +3,10 @@ An accumulator that tracks the cumulative log prior during model execution. +Note that the log prior stored in here is always calculated based on unlinked +parameters, i.e., the value of `logp` is independent of whether the VarInfo is +linked or not. + # Fields $(TYPEDFIELDS) """ @@ -19,6 +23,49 @@ Create a new `LogPriorAccumulator` accumulator with the log prior initialized to LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() +""" + LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log Jacobian (technically, +log(abs(det(J)))) during model execution. Specifically, J refers to the +Jacobian of the _link transform_, i.e., from the space of the original +distribution to unconstrained space. + +!!! note + This accumulator is only incremented if the variable is transformed by a + link function, i.e., if the VarInfo is linked (for the particular + variable that is currently being accumulated). If the variable is not + linked, the log Jacobian term will be 0. 
+ + In general, for the forward Jacobian ``\\mathbf{J}`` corresponding to the + function ``\\mathbf{y} = f(\\mathbf{x})``, + + ```math + \\log(q(\\mathbf{y})) = \\log(p(\\mathbf{x})) - \\log (|\\mathbf{J}|) + ``` + + and correspondingly: + + ```julia + getlogjoint_internal(vi) = getlogjoint(vi) - getlogjac(vi) + ``` + +# Fields +$(TYPEDFIELDS) +""" +struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + "the logabsdet of the link transform Jacobian" + logjac::T +end + +""" + LogJacobianAccumulator{T}() + +Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero. +""" +LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T)) +LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}() + """ LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator @@ -71,6 +118,7 @@ VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) VariableOrderAccumulator() = VariableOrderAccumulator{Int}() Base.copy(acc::LogPriorAccumulator) = acc +Base.copy(acc::LogJacobianAccumulator) = acc Base.copy(acc::LogLikelihoodAccumulator) = acc function Base.copy(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) @@ -79,6 +127,9 @@ end function Base.show(io::IO, acc::LogPriorAccumulator) return print(io, "LogPriorAccumulator($(repr(acc.logp)))") end +function Base.show(io::IO, acc::LogJacobianAccumulator) + return print(io, "LogJacobianAccumulator($(repr(acc.logjac)))") +end function Base.show(io::IO, acc::LogLikelihoodAccumulator) return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") end @@ -92,6 +143,9 @@ end # equality of hashes. Both of the below implementations are also different from the default # implementation for structs. 
Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp +function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return acc1.logjac == acc2.logjac +end function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return acc1.logp == acc2.logp end @@ -102,6 +156,9 @@ end function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) return isequal(acc1.logp, acc2.logp) end +function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return isequal(acc1.logjac, acc2.logjac) +end function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return isequal(acc1.logp, acc2.logp) end @@ -110,6 +167,9 @@ function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumul end Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) +function Base.hash(acc::LogJacobianAccumulator, h::UInt) + return hash((LogJacobianAccumulator, acc.logjac), h) +end function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) return hash((LogLikelihoodAccumulator, acc.logp), h) end @@ -118,16 +178,21 @@ function Base.hash(acc::VariableOrderAccumulator, h::UInt) end accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T)) split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) split(acc::VariableOrderAccumulator) = copy(acc) function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc.logp + acc2.logp) end +function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + 
return LogJacobianAccumulator(acc.logjac + acc2.logjac) +end function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc.logp + acc2.logp) end @@ -142,6 +207,9 @@ end function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc1.logp + acc2.logp) end +function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return LogJacobianAccumulator(acc1.logjac + acc2.logjac) +end function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc1.logp + acc2.logp) end @@ -150,13 +218,19 @@ function increment(acc::VariableOrderAccumulator) end Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logjac)) Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val) + logjac) + return acc + LogPriorAccumulator(logpdf(right, val)) end accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acc + LogJacobianAccumulator(logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc + accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) # Note that it's important to use the loglikelihood function here, not logpdf, because @@ -174,6 +248,11 @@ accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) end +function Base.convert( + ::Type{LogJacobianAccumulator{T}}, 
acc::LogJacobianAccumulator +) where {T} + return LogJacobianAccumulator(convert(T, acc.logjac)) +end function Base.convert( ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator ) where {T} @@ -197,6 +276,9 @@ end function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) end +function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T} + return LogJacobianAccumulator(convert(T, acc.logjac)) +end function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} return LogLikelihoodAccumulator(convert(T, acc.logp)) end @@ -206,6 +288,7 @@ function default_accumulators( ) where {FloatT,IntT} return AccumulatorTuple( LogPriorAccumulator{FloatT}(), + LogJacobianAccumulator{FloatT}(), LogLikelihoodAccumulator{FloatT}(), VariableOrderAccumulator{IntT}(), ) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 3c092c06b..3b790576a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,7 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -29,10 +29,37 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with a -function that specifies how to extract the log density, and the type of -VarInfo to be used. These must be known in order to calculate the log density -(using [`DynamicPPL.evaluate!!`](@ref)). +This information can be extracted using the LogDensityProblems.jl interface, +specifically, using `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient`. 
If `adtype` is nothing, then only +`logdensity` is implemented. If `adtype` is a concrete AD backend type, then +`logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the +box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring + any effects of linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring + any effects of linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected + by linking, since transforms are only applied to random variables) + +!!! note + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the + result of `LogDensityProblems.logdensity(f, x)` will depend on whether the + `LogDensityFunction` was created with a linked or unlinked VarInfo. This + is done primarily to ease interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created +for you. If you provide a different function, you have to manually create a +VarInfo and pass it as the third argument. If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -40,10 +67,6 @@ gradient of the log density. Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD backend itself to have been loaded (e.g. with `import Backend`). -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. -If `adtype` is nothing, then only `logdensity` is implemented. 
If `adtype` is a -concrete AD backend type, then `logdensity_and_gradient` is also implemented. - # Fields $(FIELDS) @@ -74,7 +97,7 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 @@ -99,7 +122,7 @@ struct LogDensityFunction{ } <: AbstractModel "model used for evaluation" model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint`." + "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." getlogdensity::F "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V @@ -110,7 +133,7 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) @@ -180,7 +203,15 @@ function ldf_default_varinfo(::Model, getlogdensity::Function) return error(msg) end -ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) +ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) +end + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) +end function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) diff --git a/src/model.jl b/src/model.jl index 93e77eaec..dbbe0b85b 100644 --- a/src/model.jl 
+++ b/src/model.jl @@ -995,6 +995,10 @@ Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model) Return the log joint probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logjoint` does not depend on whether `VarInfo` has been linked +or not. + See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) @@ -1042,6 +1046,10 @@ end Return the log prior probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logprior` does not depend on whether `VarInfo` has been linked +or not. + See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 44882f91e..dea432022 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -74,6 +74,9 @@ function accumulate_assume!!( # T is the element type of the vectors that are the values of `acc.logps`. Usually # it's LogProbType. T = eltype(last(fieldtypes(eltype(acc.logps)))) + # Note that in only accumulating LogPrior, we effectively ignore logjac + # (since we want to return log densities that don't depend on the + # linking status of the VarInfo). 
subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) push!(acc, vn, subacc.logp) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abb93a0ab..0a2818e2a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -122,18 +122,18 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! 
- getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -Inf ``` @@ -476,7 +476,7 @@ function assume( f = to_maybe_linked_internal_transform(vi, vn, dist) value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + vi = accumulate_assume!!(vi, value, logjac, vn, dist) return value, vi end @@ -494,6 +494,7 @@ end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) @@ -619,8 +620,8 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) vi_new = Accessors.@set(vi.values = y) - if hasacc(vi_new, Val(:LogPrior)) - vi_new = acclogprior!!(vi_new, -logjac) + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = acclogjac!!(vi_new, logjac) end return settrans!!(vi_new, t) end @@ -632,10 +633,13 @@ function invlink!!( ) b = t.bijector y = vi.values - x, logjac = with_logabsdet_jacobian(b, y) + x, inv_logjac = with_logabsdet_jacobian(b, y) vi_new = Accessors.@set(vi.values = x) - if hasacc(vi_new, Val(:LogPrior)) - vi_new = acclogprior!!(vi_new, logjac) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. 
+ if hasacc(vi_new, Val(:LogJacobian)) + vi_new = acclogjac!!(vi_new, inv_logjac) end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d4f6f9a1d..1ac33a481 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link +using DynamicPPL: + Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -224,7 +225,7 @@ function run_ad( benchmark::Bool=false, atol::AbstractFloat=100 * eps(), rtol::AbstractFloat=sqrt(eps()), - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, rng::AbstractRNG=default_rng(), varinfo::AbstractVarInfo=link(VarInfo(rng, model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9b82cd8b4..5f0a6d3e5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -201,6 +201,11 @@ function resetlogp!!(vi::ThreadSafeVarInfo) zero, vi.accs_by_thread[i], Val(:LogPrior) ) end + if hasacc(vi, Val(:LogJacobian)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogJacobian) + ) + end if hasacc(vi, Val(:LogLikelihood)) vi.accs_by_thread[i] = map_accumulator( zero, vi.accs_by_thread[i], Val(:LogLikelihood) diff --git a/src/transforming.jl b/src/transforming.jl index e3da0ff29..56f861cff 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -15,8 +15,8 @@ NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume( ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} - r = vi[vn, right] - lp = Bijectors.logpdf_with_trans(right, r, 
!isinverse) + # vi[vn, right] always provides the value in unlinked space. + x = vi[vn, right] if istrans(vi, vn) isinverse || @warn "Trying to link an already transformed variable ($vn)" @@ -24,13 +24,11 @@ function tilde_assume( isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : link_transform(right)(r) - if hasacc(vi, Val(:LogPrior)) - vi = acclogprior!!(vi, lp) - end - return r, setindex!!(vi, r_transformed, vn) + transform = isinverse ? identity : link_transform(right) + y, logjac = with_logabsdet_jacobian(transform, x) + vi = accumulate_assume!!(vi, x, logjac, vn, right) + vi = setindex!!(vi, y, vn) + return x, vi end function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) @@ -53,21 +51,7 @@ function _transform!!( ) # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: model = contextualize(model, setleafcontext(model.context, ctx)) - # but we do not need to use any accumulators other than LogPriorAccumulator - # (which is affected by the Jacobian of the transformation). - accs = getaccs(vi) - has_logprior = haskey(accs, Val(:LogPrior)) - if has_logprior - old_logprior = getacc(accs, Val(:LogPrior)) - vi = setaccs!!(vi, (old_logprior,)) - end vi = settrans!!(last(evaluate!!(model, vi)), t) - # Restore the accumulators. - if has_logprior - new_logprior = getacc(vi, Val(:LogPrior)) - accs = setacc!!(accs, new_logprior) - end - vi = setaccs!!(vi, accs) return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index d8233ae07..7b819c58f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1148,8 +1148,8 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. 
setval!(md, yvec, vn) - if hasacc(vi, Val(:LogPrior)) - vi = acclogprior!!(vi, -logjac) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) end return vi end @@ -1187,8 +1187,8 @@ function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1203,8 +1203,8 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1351,10 +1351,13 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. 
+ new_varinfo = acclogjac!!(new_varinfo, inv_logjac) end return new_varinfo end @@ -1367,10 +1370,13 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) end return new_varinfo end @@ -1382,7 +1388,7 @@ end vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} expr = quote - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) end mds = Expr(:tuple) for f in metadata_names @@ -1391,10 +1397,10 @@ end mds.args, quote begin - md, logjac = _invlink_metadata!!( + md, inv_logjac = _invlink_metadata!!( model, varinfo, metadata.$f, vns.$f ) - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac md end end, @@ -1407,7 +1413,7 @@ end push!( expr.args, quote - (NamedTuple{$metadata_names}($mds), cumulative_logjac) + (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) end, ) return expr @@ -1415,7 +1421,7 @@ end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. 
vals_new = map(vns) do vn @@ -1430,11 +1436,11 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ y = getindex_internal(varinfo, vn) dist = getdist(varinfo, vn) f = from_linked_internal_transform(varinfo, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) + x, inv_logjac = with_logabsdet_jacobian(f, y) # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1459,25 +1465,25 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.flags, ), - cumulative_logjac + cumulative_inv_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) - new_val, logjac = with_logabsdet_jacobian(transform, old_val) + new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. 
- cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata, cumulative_logjac + return metadata, cumulative_inv_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not diff --git a/test/accumulators.jl b/test/accumulators.jl index 5963ad8b5..506821c38 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -87,7 +87,9 @@ using DynamicPPL: vn = @varname(x) dist = Normal() @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == - LogPriorAccumulator(1.0 + logjac + logpdf(dist, val)) + LogPriorAccumulator(1.0 + logpdf(dist, val)) + @test accumulate_assume!!(LogJacobianAccumulator(2.0), val, logjac, vn, dist) == + LogJacobianAccumulator(2.0 + logjac) @test accumulate_assume!!( LogLikelihoodAccumulator(1.0), val, logjac, vn, dist ) == LogLikelihoodAccumulator(1.0) @@ -101,6 +103,8 @@ using DynamicPPL: vn = @varname(x) @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogJacobianAccumulator(1.0), right, left, vn) == + LogJacobianAccumulator(1.0) @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == LogLikelihoodAccumulator(1.0 + logpdf(right, left)) @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == diff --git a/test/ad.jl b/test/ad.jl index 308894ada..371e79b06 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -30,7 +30,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint, linked_varinfo) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using 
ForwardDiff @@ -52,17 +52,17 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) else @test run_ad( @@ -113,7 +113,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest spl = Sampler(MyEmptyAlg()) sampling_model = contextualize(model, SamplingContext(model.context)) ldf = LogDensityFunction( - sampling_model, getlogjoint; adtype=AutoReverseDiff(; compile=true) + sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) ) x = ldf.varinfo[:] @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any diff --git a/test/linking.jl b/test/linking.jl index b0c2dcb5c..cae101c72 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -84,8 +84,11 @@ end else DynamicPPL.link(vi, model) end - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2) + # Difference between the internal logjoints should just be the log-absdet-jacobian "correction". + @test DynamicPPL.getlogjoint_internal(vi) - + DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) + # The non-internal logjoint should be the same since it doesn't depend on linking. 
+ @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -98,7 +101,12 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + # The non-internal logjoint should still be the same, again since + # it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) + # The internal logjoint should also be the same as before the round-trip linking. + @test DynamicPPL.getlogjoint_internal(vi_invlinked) ≈ + DynamicPPL.getlogjoint_internal(vi) end end @@ -130,7 +138,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. - @test !(getlogjoint(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +146,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogjoint(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end @@ -164,7 +172,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogjoint(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. 
vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +180,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogjoint(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index c4d0d6beb..fbd868f71 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -26,8 +26,11 @@ end loglikelihood(model, vi) @testset "$(varinfo)" for varinfo in varinfos + # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] + # ... because it has to match with `logjoint(model, vi)`, which always returns + # the unlinked value @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) end diff --git a/test/model.jl b/test/model.jl index daa3cc743..81f84e548 100644 --- a/test/model.jl +++ b/test/model.jl @@ -485,11 +485,18 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() DynamicPPL.untyped_simple_varinfo(model), ] @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + logjoint = getlogjoint(varinfo) # unlinked space varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) + # getlogjoint should return the same result as before it was linked @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ logjoint + # getlogjoint_internal shouldn't + @test getlogjoint_internal(varinfo_linked) ≈ + getlogjoint_internal(varinfo_linked_result) + @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index e300c651e..3cca1b5dc 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ 
-89,38 +89,40 @@ @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), - SimpleVarInfo(values_constrained), - SimpleVarInfo(DynamicPPL.VarNamedVector()), - DynamicPPL.typed_varinfo(model), + @testset "$name" for (name, vi) in ( + ("SVI{Dict}", SimpleVarInfo(Dict())), + ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), + ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), + ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end vi = last(DynamicPPL.evaluate!!(model, vi)) - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogjoint(vi_linked) - values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Calculate ground truth + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( model, values_constrained... ) - # Should result in the correct logjoint. + _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_constrained... + ) + + # `link!!` + vi_linked = link!!(deepcopy(vi), model) + lp_unlinked = getlogjoint(vi_linked) + lp_linked = getlogjoint_internal(vi_linked) @test lp_linked ≈ lp_linked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_linked) ≈ lp_linked + @test lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_linked) ≈ lp_unlinked # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogjoint(vi_invlinked) - lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - # Should result in the correct logjoint. - @test lp_invlinked ≈ lp_invlinked_true - # Should be approx. the same as the "lazy" transformation. 
- @test logjoint(model, vi_invlinked) ≈ lp_invlinked + lp_unlinked = getlogjoint(vi_invlinked) + also_lp_unlinked = getlogjoint_internal(vi_invlinked) + @test lp_unlinked ≈ lp_unlinked_true + @test also_lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_invlinked) ≈ lp_unlinked # Should result in same values. @test all( @@ -143,10 +145,10 @@ end svi_vnv = SimpleVarInfo(vnv) - @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( - svi_nt, - svi_dict, - svi_vnv, + @testset "$name" for (name, svi) in ( + ("NamedTuple", svi_nt), + ("Dict", svi_dict), + ("VarNamedVector", svi_vnv), # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. # DynamicPPL.settrans!!(deepcopy(svi_nt), true), # DynamicPPL.settrans!!(deepcopy(svi_dict), true), @@ -250,7 +252,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogjoint(svi) + lp = getlogjoint_internal(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -281,31 +283,36 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) + # NOTE: Evaluating a linked VarInfo, **specifically when the transformation + # is static**, will result in an invlinked VarInfo. This is because of + # `maybe_invlink_before_eval!`, which only invlinks if the transformation + # is static. (src/abstract_varinfo.jl) + retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_linked_result, @varname(s)) + DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result # `m` should not be transformed. 
@test vi_linked[@varname(m)] == retval.m - @test vi_linked_result[@varname(m)] == retval.m + @test vi_unlinked_again[@varname(m)] == retval.m - # Compare to truth. - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Get ground truths + retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, retval.s, retval.m ) + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ DynamicPPL.tovec(retval_unconstrained.s) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ DynamicPPL.tovec(retval_unconstrained.m) - # The resulting varinfo should hold the correct logp. - lp = getlogjoint(vi_linked_result) - @test lp ≈ lp_true + # The unlinked varinfo should hold the unlinked logp. + lp_unlinked = getlogjoint(vi_unlinked_again) + @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index dad54f024..16a9a857d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -167,8 +167,9 @@ end vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b + @test getlogjac(vi) == 0.0 @test getloglikelihood(vi) == lp_c + lp_d - @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d) + @test getlogp(vi) == (; logprior=lp_a + lp_b, logjac=0.0, loglikelihood=lp_c + lp_d) @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d @test get_num_produce(vi) == 2 @test begin @@ -183,17 +184,21 @@ end vi = setlogprior!!(vi, -1.0) getlogprior(vi) == -1.0 end + @test begin + vi = setlogjac!!(vi, -1.0) + getlogjac(vi) == -1.0 + end @test begin vi = setloglikelihood!!(vi, -1.0) getloglikelihood(vi) == -1.0 end @test begin - vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0)) - getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0) + vi = setlogp!!(vi, (logprior=-3.0, logjac=-3.0, 
loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, logjac=-3.0, loglikelihood=-3.0) end @test begin vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) - getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0) + getlogp(vi) == (; logprior=-2.0, logjac=-3.0, loglikelihood=-2.0) end @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) @@ -206,7 +211,7 @@ end # need regex because 1.11 and 1.12 throw different errors (in 1.12 the # missing field is surrounded by backticks) @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) - @test_throws r"has no field `?LogLikelihood" getlogp(vi) + @test_throws r"has no field `?LogJacobian" getlogp(vi) @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) @test_throws r"has no field `?VariableOrder" get_num_produce(vi) @test begin @@ -552,71 +557,52 @@ end end end - @testset "istrans" begin + @testset "logp evaluation on linked varinfo" begin @model demo_constrained() = x ~ truncated(Normal(); lower=0) model = demo_constrained() vn = @varname(x) dist = truncated(Normal(); lower=0) - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. 
- - ## `untyped_varinfo` - vi = DynamicPPL.untyped_varinfo(model) + function test_linked_varinfo(model, vi) + # vn and dist are taken from the containing scope + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test istrans(vi, vn) + @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getloglikelihood(vi) == 0.0 + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. 
- vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) end @testset "values_as" begin @@ -719,8 +705,8 @@ end lp = logjoint(model, varinfo) @test lp ≈ lp_true @test getlogjoint(varinfo) ≈ lp_true - lp_linked = getlogjoint(varinfo_linked) - @test lp_linked ≈ lp_linked_true + lp_linked_internal = getlogjoint_internal(varinfo_linked) + @test lp_linked_internal ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for # the true values, e.g. `value_true`. 
@@ -732,6 +718,7 @@ end ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) @test getlogjoint(varinfo_invlinked) ≈ lp_true + @test getlogjoint_internal(varinfo_invlinked) ≈ lp_true end end end From 05cd886d56ed89ffe8fb811b51c799e88392d399 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 25 Jul 2025 11:12:39 +0100 Subject: [PATCH 24/27] Fix behaviour of `set_retained_vns_del!` for `num_produce == 0` (#1000) --- src/varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 7b819c58f..101eb6d50 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1829,19 +1829,19 @@ end """ set_retained_vns_del!(vi::VarInfo) -Set the `"del"` flag of variables in `vi` with `order > num_produce` to `true`. +Set the `"del"` flag of variables in `vi` with `order > num_produce` to `true`. If +`num_produce` is `0`, _all_ variables will have their `"del"` flag set to `true`. Will error if `vi` does not have an accumulator for `VariableOrder`. """ function set_retained_vns_del!(vi::VarInfo) if !hasacc(vi, Val(:VariableOrder)) msg = "`vi` must have an accumulator for VariableOrder to set the `del` flag." - raise(ArgumentError(msg)) + throw(ArgumentError(msg)) end num_produce = get_num_produce(vi) for vn in keys(vi) - order = getorder(vi, vn) - if order > num_produce + if num_produce == 0 || getorder(vi, vn) > num_produce set_flag!(vi, vn, "del") end end From c6c0cbc3fdc39e3eff2b8629079e002d17a4f61c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 25 Jul 2025 15:22:08 +0100 Subject: [PATCH 25/27] `InitContext`, part 2 - Move `hasvalue` and `getvalue` to AbstractPPL; enforce key type of `AbstractDict` (#980) * point to unmerged AbstractPPL branch * Remove code that was moved to AbstractPPL * Remove Dictionaries with Any key type * Fix bad merge conflict resolution * Fix doctests * Point to AbstractPPL@0.13 This reverts commit 709dc9ecff1ed7cde441447ca6a6108f182a219c. 
* Fix doctests * Fix docs AbstractPPL bound * Remove stray `Pkg.update()` --- Project.toml | 2 +- docs/Project.toml | 2 +- src/DynamicPPL.jl | 2 +- src/model.jl | 12 ++- src/simple_varinfo.jl | 16 ++-- src/test_utils/varinfo.jl | 2 +- src/utils.jl | 193 -------------------------------------- src/values_as_in_model.jl | 4 +- src/varnamedvector.jl | 2 +- test/Project.toml | 2 +- test/simple_varinfo.jl | 2 +- test/varinfo.jl | 11 +-- 12 files changed, 30 insertions(+), 220 deletions(-) diff --git a/Project.toml b/Project.toml index c23845b8c..1f37515ab 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.11, 0.12" +AbstractPPL = "0.13" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/docs/Project.toml b/docs/Project.toml index 5797a8fd1..1cd1d90d2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] -AbstractPPL = "0.11, 0.12" +AbstractPPL = "0.13" Accessors = "0.1" DataStructures = "0.18" Distributions = "0.25" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 15d39014e..c4e7e6fba 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -23,7 +23,7 @@ using DocStringExtensions using Random: Random # For extending -import AbstractPPL: predict +import AbstractPPL: predict, hasvalue, getvalue # TODO: Remove these when it's possible. import Bijectors: link, invlink diff --git a/src/model.jl b/src/model.jl index dbbe0b85b..ac9968cf2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. 
""" function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) + x = last( + evaluate_and_sample!!( + rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) + ), + ) return values_as(x, T) end @@ -1032,7 +1036,7 @@ julia> logjoint(demo_model([1., 2.]), chain); function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1090,7 +1094,7 @@ julia> logprior(demo_model([1., 2.]), chain); function logprior(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1144,7 +1148,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain); function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0a2818e2a..1538428fd 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -62,7 +62,7 @@ ERROR: type NamedTuple has no field x [...] 
julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -70,11 +70,11 @@ julia> # (✓) Sort of fast, but only possible at runtime. julia> # In addtion, we can only access varnames as they appear in the model! vi[@varname(x)] -ERROR: KeyError: key x not found +ERROR: x was not found in the dictionary provided [...] julia> vi[@varname(x[1:2])] -ERROR: KeyError: key x[1:2] not found +ERROR: x[1:2] was not found in the dictionary provided [...] ``` @@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 @@ -177,11 +177,11 @@ julia> svi_dict[@varname(m.a[1])] 1.0 julia> svi_dict[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: m.a[2] was not found in the dictionary provided [...] julia> svi_dict[@varname(m.b)] -ERROR: type NamedTuple has no field b +ERROR: m.b was not found in the dictionary provided [...] ``` """ @@ -212,7 +212,7 @@ end function SimpleVarInfo(values) return SimpleVarInfo{LogProbType}(values) end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict}) +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) return if isempty(values) # Can't infer from values, so we just use default. 
SimpleVarInfo{LogProbType}(values) @@ -264,7 +264,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} end function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict()) + varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) return last(evaluate_and_sample!!(model, varinfo)) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 542fc17fc..26e2aa7ca 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -34,7 +34,7 @@ function setup_varinfos( # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) - svi_untyped = SimpleVarInfo(OrderedDict()) + svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) varinfos = map(( diff --git a/src/utils.jl b/src/utils.jl index 0f4d98b11..af2891a2b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -751,199 +751,6 @@ function unflatten(original::AbstractDict, x::AbstractVector) return D(zip(keys(original), unflatten(collect(values(original)), x))) end -# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl. -""" - getvalue(vals, vn::VarName) - -Return the value(s) in `vals` represented by `vn`. - -Note that this method is different from `getindex`. See examples below. - -# Examples - -For `NamedTuple`: - -```jldoctest -julia> vals = (x = [1.0],); - -julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] 
-``` - -For `AbstractDict`: - -```jldoctest -julia> vals = Dict(@varname(x) => [1.0]); - -julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] -``` - -In the `AbstractDict` case we can also have keys such as `v[1]`: - -```jldoctest -julia> vals = Dict(@varname(x[1]) => [1.0,]); - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1][2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> DynamicPPL.getvalue(vals, @varname(x[2][1])) -ERROR: KeyError: key x[2][1] not found -[...] -``` -""" -getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) -getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) - -""" - hasvalue(vals, vn::VarName) - -Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref). 
- -# Examples -With `x` as a `NamedTuple`: - -```jldoctest -julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x)) -true - -julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1])) -false - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x)) -true - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1])) -true - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2])) -false -``` - -With `x` as a `AbstractDict`: - -```jldoctest -julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) -false - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) -false -``` - -In the `AbstractDict` case we can also have keys such as `v[1]`: - -```jldoctest -julia> vals = Dict(@varname(x[1]) => [1.0,]); - -julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey` -true - -julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey` -true - -julia> DynamicPPL.hasvalue(vals, @varname(x[1][2])) -false - -julia> DynamicPPL.hasvalue(vals, @varname(x[2][1])) -false -``` -""" -function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} - # LHS: Ensure that `nt` indeed has the property we want. - # RHS: Ensure that the optic can view into `nt`. - return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) -end - -# For `dictlike` we need to check wether `vn` is "immediately" present, or -# if some ancestor of `vn` is present in `dictlike`. -function hasvalue(vals::AbstractDict, vn::VarName) - # First we check if `vn` is present as is. - haskey(vals, vn) && return true - - # If `vn` is not present, we check any parent-varnames by attempting - # to split the optic into the key / `parent` and the extraction optic / `child`. 
- # If `issuccess` is `true`, we found such a split, and hence `vn` is present. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(vals, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - # Return early if no such split could be found. - issuccess || return false - - # At this point we just need to check that we `canview` the value. - value = vals[VarName{getsym(vn)}(keyoptic)] - - return canview(child, value) -end - -""" - nested_getindex(values::AbstractDict, vn::VarName) - -Return value corresponding to `vn` in `values` by also looking -in the the actual values of the dict. -""" -function nested_getindex(values::AbstractDict, vn::VarName) - maybeval = get(values, vn, nothing) - if maybeval !== nothing - return maybeval - end - - # Split the optic into the key / `parent` and the extraction optic / `child`. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(values, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - # If we found a valid split, then we can extract the value. - if !issuccess - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) - end - - # TODO: Should we also check that we `canview` the extracted `value` - # rather than just let it fail upon `get` call? 
- value = values[VarName{getsym(vn)}(keyoptic)] - return child(value) -end - """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 1fa0555f0..df663bf54 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -12,12 +12,12 @@ $(TYPEDFIELDS) """ struct ValuesAsInModelAccumulator <: AbstractAccumulator "values that are extracted from the model" - values::OrderedDict + values::OrderedDict{<:VarName} "whether to extract variables on the LHS of :=" include_colon_eq::Bool end function ValuesAsInModelAccumulator(include_colon_eq) - return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq) + return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq) end function Base.copy(acc::ValuesAsInModelAccumulator) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 965db96d5..5de0874c9 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} end # See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how -# they differ from `haskey` and `getindex`. They can be found in src/utils.jl. +# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl. # TODO(mhauru) This is tricky to implement in the general case, and the below implementation # only covers some simple cases. It's probably sufficient in most situations though. 
diff --git a/test/Project.toml b/test/Project.toml index afecba1c4..6da3786f5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.11, 0.12" +AbstractPPL = "0.13" Accessors = "0.1" Aqua = "0.8" Bijectors = "0.15.1" diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 3cca1b5dc..be6deb96e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -90,7 +90,7 @@ DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(Dict())), + ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), diff --git a/test/varinfo.jl b/test/varinfo.jl index 16a9a857d..bd0c0a987 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -110,7 +110,7 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict())) + test_base(SimpleVarInfo(Dict{VarName,Any}())) test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @@ -597,7 +597,7 @@ end test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) + vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` @@ -737,11 +737,10 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the severely inconcrete `SimpleVarInfo` types, since checking for type + # Skip the inconcrete `SimpleVarInfo` types, since checking for type # stability for them doesn't make much sense anyway. 
- if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} || - varinfo isa - DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}} + if varinfo isa SimpleVarInfo{<:AbstractDict} || + varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} continue end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) From 5d9e934c0bf5d842a8cada97a2ddcfe9ce079582 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 28 Jul 2025 15:34:37 +0100 Subject: [PATCH 26/27] Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator (#999) * logjac accumulator * Fix tests * Fix a whole bunch of stuff * Fix final tests * Fix docs * Fix docs/doctests * Fix maths in LogJacobianAccumulator docstring * Twiddle with a comment * Add changelog * Simplify accs with LogProbAccumulator * Replace + with accumulate for LogProbAccs * Introduce merge and subset for accs * Improve acc tests * Fix docstring typo. Co-authored-by: Penelope Yong * Fix merge --------- Co-authored-by: Penelope Yong --- src/DynamicPPL.jl | 1 + src/abstract_varinfo.jl | 10 +- src/accumulators.jl | 69 ++++++++++ src/default_accumulators.jl | 251 +++++++++++++++++------------------- src/simple_varinfo.jl | 6 +- src/utils.jl | 7 + src/varinfo.jl | 5 +- test/accumulators.jl | 139 ++++++++++++++++++-- 8 files changed, 332 insertions(+), 156 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c4e7e6fba..4a13c9878 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -66,6 +66,7 @@ export AbstractVarInfo, setlogprior!!, setlogjac!!, setloglikelihood!!, + acclogp, acclogp!!, acclogjac!!, acclogprior!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index cf5ce5706..caf6dc16c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -358,7 +358,7 @@ Add `logp` to the value of the log of the prior probability in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). 
""" function acclogprior!!(vi::AbstractVarInfo, logp) - return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior)) end """ @@ -369,9 +369,7 @@ Add `logjac` to the value of the log Jacobian in `vi`. See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). """ function acclogjac!!(vi::AbstractVarInfo, logjac) - return map_accumulator!!( - acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian) - ) + return map_accumulator!!(acc -> acclogp(acc, logjac), vi, Val(:LogJacobian)) end """ @@ -382,9 +380,7 @@ Add `logp` to the value of the log of the likelihood in `vi`. See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). """ function accloglikelihood!!(vi::AbstractVarInfo, logp) - return map_accumulator!!( - acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood) - ) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood)) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 0dcf9c7cf..b560307b7 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` +If two accumulators of the same type should be merged in some non-trivial way, other than +always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined. + +If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should +do something other than copy the original accumulator, then +`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.` + See the documentation for each of these functions for more details. """ abstract type AbstractAccumulator end @@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function. 
""" convert_eltype(::Type, acc::AbstractAccumulator) = acc +""" + subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName}) + +Return a new accumulator that only contains the information for the `VarName`s in `vns`. + +By default returns a copy of `acc`. Subtypes should override this behaviour as needed. +""" +subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc) + +""" + merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) + +Merge two accumulators of the same type. Returns a new accumulator of the same type. + +By default returns a copy of `acc2`. Subtypes should override this behaviour as needed. +""" +Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2) + """ AccumulatorTuple{N,T<:NamedTuple} @@ -158,6 +183,50 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) return AccumulatorTuple(convert(T, accs.nt)) end +""" + subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + +Replace each accumulator `acc` in `at` with `subset(acc, vns)`. +""" +function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt)) +end + +""" + _joint_keys(nt1::NamedTuple, nt2::NamedTuple) + +A helper function that returns three tuples of keys given two `NamedTuple`s: +The keys only in `nt1`, only in `nt2`, and in both, and in that order. + +Implemented as a generated function to enable constant propagation of the result in `merge`. +""" +@generated function _joint_keys( + nt1::NamedTuple{names1}, nt2::NamedTuple{names2} +) where {names1,names2} + only_in_nt1 = tuple(setdiff(names1, names2)...) + only_in_nt2 = tuple(setdiff(names2, names1)...) + in_both = tuple(intersect(names1, names2)...) + return :($only_in_nt1, $only_in_nt2, $in_both) +end + +""" + merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + +Merge two `AccumulatorTuple`s. 
+ +For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two +accumulators themselves. Other accumulators are copied. +""" +function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt) + accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1) + accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2) + accs_in_both = ( + merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both + ) + return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...) +end + """ setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index d503b3e64..8d51a8431 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -1,5 +1,78 @@ """ - LogPriorAccumulator{T<:Real} <: AbstractAccumulator + LogProbAccumulator{T} <: AbstractAccumulator + +An abstract type for accumulators that hold a single scalar log probability value. + +Every subtype of `LogProbAccumulator` must implement +* A method for `logp` that returns the scalar log probability value that defines it. +* A single-argument constructor that takes a `logp` value. +* `accumulator_name`, `accumulate_assume!!`, and `accumulate_observe!!` methods like any + other accumulator. + +`LogProbAccumulator` provides implementations for other common functions, like convenience +constructors, `copy`, `show`, `==`, `isequal`, `hash`, `split`, and `combine`. + +This type has no great conceptual significance, it just reduces code duplication between +types like LogPriorAccumulator, LogJacobianAccumulator, and LogLikelihoodAccumulator. +""" +abstract type LogProbAccumulator{T<:Real} <: AbstractAccumulator end + +# The first of the below methods sets AccType{T}() = AccType(zero(T)) for any +# AccType <: LogProbAccumulator{T}. The second one sets LogProbType as the default eltype T +# when calling AccType(). 
+""" + LogProbAccumulator{T}() + +Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero. +""" +(::Type{AccType})() where {T<:Real,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) +(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType}() + +Base.copy(acc::LogProbAccumulator) = acc + +function Base.show(io::IO, acc::LogProbAccumulator) + return print(io, "$(string(basetypeof(acc)))($(repr(logp(acc))))") +end + +# Note that == and isequal are different, and equality under the latter should imply +# equality of hashes. Both of the below implementations are also different from the default +# implementation for structs. +function Base.:(==)(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return accumulator_name(acc1) === accumulator_name(acc2) && logp(acc1) == logp(acc2) +end + +function Base.isequal(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return basetypeof(acc1) === basetypeof(acc2) && isequal(logp(acc1), logp(acc2)) +end + +Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h) + +split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) + +function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator) + if basetypeof(acc) !== basetypeof(acc2) + msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))" + throw(ArgumentError(msg)) + end + return basetypeof(acc)(logp(acc) + logp(acc2)) +end + +acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val) + +Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc))) + +function Base.convert( + ::Type{AccType}, acc::LogProbAccumulator +) where {T,AccType<:LogProbAccumulator{T}} + return AccType(convert(T, logp(acc))) +end + +function convert_eltype(::Type{T}, acc::LogProbAccumulator) where {T} + return basetypeof(acc)(convert(T, logp(acc))) +end + +""" + LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks 
the cumulative log prior during model execution. @@ -10,21 +83,22 @@ linked or not. # Fields $(TYPEDFIELDS) """ -struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator +struct LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} "the scalar log prior value" logp::T end -""" - LogPriorAccumulator{T}() +logp(acc::LogPriorAccumulator) = acc.logp -Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero. -""" -LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) -LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) + return acclogp(acc, logpdf(right, val)) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc """ - LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log Jacobian (technically, log(abs(det(J)))) during model execution. Specifically, J refers to the @@ -53,39 +127,44 @@ distribution to unconstrained space. # Fields $(TYPEDFIELDS) """ -struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator +struct LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} "the logabsdet of the link transform Jacobian" logjac::T end -""" - LogJacobianAccumulator{T}() +logp(acc::LogJacobianAccumulator) = acc.logjac -Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero. 
-""" -LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T)) -LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}() +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian + +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acclogp(acc, logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc """ - LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator + LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} An accumulator that tracks the cumulative log likelihood during model execution. # Fields $(TYPEDFIELDS) """ -struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator +struct LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} "the scalar log likelihood value" logp::T end -""" - LogLikelihoodAccumulator{T}() +logp(acc::LogLikelihoodAccumulator) = acc.logp -Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero. 
-""" -LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)) -LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}() +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acclogp(acc, Distributions.loglikelihood(right, left)) +end """ VariableOrderAccumulator{T} <: AbstractAccumulator @@ -117,85 +196,32 @@ VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) VariableOrderAccumulator() = VariableOrderAccumulator{Int}() -Base.copy(acc::LogPriorAccumulator) = acc -Base.copy(acc::LogJacobianAccumulator) = acc -Base.copy(acc::LogLikelihoodAccumulator) = acc function Base.copy(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) end -function Base.show(io::IO, acc::LogPriorAccumulator) - return print(io, "LogPriorAccumulator($(repr(acc.logp)))") -end -function Base.show(io::IO, acc::LogJacobianAccumulator) - return print(io, "LogJacobianAccumulator($(repr(acc.logjac)))") -end -function Base.show(io::IO, acc::LogLikelihoodAccumulator) - return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") -end function Base.show(io::IO, acc::VariableOrderAccumulator) return print( - io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))" + io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))" ) end -# Note that == and isequal are different, and equality under the latter should imply -# equality of hashes. Both of the below implementations are also different from the default -# implementation for structs. 
-Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp -function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return acc1.logjac == acc2.logjac -end -function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return acc1.logp == acc2.logp -end function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order end -function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) - return isequal(acc1.logp, acc2.logp) -end -function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return isequal(acc1.logjac, acc2.logjac) -end -function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return isequal(acc1.logp, acc2.logp) -end function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order) end -Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) -function Base.hash(acc::LogJacobianAccumulator, h::UInt) - return hash((LogJacobianAccumulator, acc.logjac), h) -end -function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) - return hash((LogLikelihoodAccumulator, acc.logp), h) -end function Base.hash(acc::VariableOrderAccumulator, h::UInt) return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h) end -accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior -accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian -accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder -split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) -split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T)) -split(::LogLikelihoodAccumulator{T}) where {T} = 
LogLikelihoodAccumulator(zero(T)) split(acc::VariableOrderAccumulator) = copy(acc) -function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) - return LogPriorAccumulator(acc.logp + acc2.logp) -end -function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return LogJacobianAccumulator(acc.logjac + acc2.logjac) -end -function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return LogLikelihoodAccumulator(acc.logp + acc2.logp) -end function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) # Note that assumptions are not allowed in parallelised blocks, and thus the # dictionaries should be identical. @@ -204,60 +230,16 @@ function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) ) end -function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) - return LogPriorAccumulator(acc1.logp + acc2.logp) -end -function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) - return LogJacobianAccumulator(acc1.logjac + acc2.logjac) -end -function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) - return LogLikelihoodAccumulator(acc1.logp + acc2.logp) -end function increment(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order) end -Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) -Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logjac)) -Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) - -function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val)) -end -accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc - -function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) - return acc + LogJacobianAccumulator(logjac) -end -accumulate_observe!!(acc::LogJacobianAccumulator, right, 
left, vn) = acc - -accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc -function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) - # Note that it's important to use the loglikelihood function here, not logpdf, because - # they handle vectors differently: - # https://github.com/JuliaStats/Distributions.jl/issues/1972 - return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left)) -end - function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right) acc.order[vn] = acc.num_produce return acc end accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc) -function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} - return LogPriorAccumulator(convert(T, acc.logp)) -end -function Base.convert( - ::Type{LogJacobianAccumulator{T}}, acc::LogJacobianAccumulator -) where {T} - return LogJacobianAccumulator(convert(T, acc.logjac)) -end -function Base.convert( - ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator -) where {T} - return LogLikelihoodAccumulator(convert(T, acc.logp)) -end function Base.convert( ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator ) where {ElType,VnType} @@ -273,15 +255,6 @@ end # convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to # deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is # horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. 
-function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} - return LogPriorAccumulator(convert(T, acc.logp)) -end -function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T} - return LogJacobianAccumulator(convert(T, acc.logjac)) -end -function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} - return LogLikelihoodAccumulator(convert(T, acc.logp)) -end function default_accumulators( ::Type{FloatT}=LogProbType, ::Type{IntT}=Int @@ -293,3 +266,19 @@ function default_accumulators( VariableOrderAccumulator{IntT}(), ) end + +function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName}) + order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order) + return VariableOrderAccumulator(acc.num_produce, order) +end + +""" + merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + +Merge two `VariableOrderAccumulator` instances. + +The `num_produce` field of the return value is the `num_produce` of `acc2`. 
+""" +function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order)) +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1538428fd..4997b4b8d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -417,7 +417,9 @@ Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.values = _subset(varinfo.values, vns) + return SimpleVarInfo( + _subset(varinfo.values, vns), subset(getaccs(varinfo), vns), varinfo.transformation + ) end function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} @@ -454,7 +456,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - accs = copy(getaccs(varinfo_right)) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) diff --git a/src/utils.jl b/src/utils.jl index af2891a2b..d3371271f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1139,3 +1139,10 @@ function group_varnames_by_symbol(vns::VarNameTuple) elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) return NamedTuple{syms}(elements) end + +""" + basetypeof(x) + +Return `typeof(x)` stripped of its type parameters. 
+""" +basetypeof(x::T) where {T} = Base.typename(T).wrapper diff --git a/src/varinfo.jl b/src/varinfo.jl index 101eb6d50..b364f5bcc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -447,7 +447,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, copy(varinfo.accs)) + return VarInfo(metadata, subset(getaccs(varinfo), vns)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -528,7 +528,8 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, copy(varinfo_right.accs)) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) + return VarInfo(metadata, accs) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) diff --git a/test/accumulators.jl b/test/accumulators.jl index 506821c38..d84fbf43d 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -39,13 +39,11 @@ using DynamicPPL: end @testset "addition and incrementation" begin - @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) == - LogPriorAccumulator(2.0f0) - @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) == - LogPriorAccumulator(2.0) - @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) == + @test acclogp(LogPriorAccumulator(1.0f0), 1.0f0) == LogPriorAccumulator(2.0f0) + @test acclogp(LogPriorAccumulator(1.0), 1.0f0) == LogPriorAccumulator(2.0) + @test acclogp(LogLikelihoodAccumulator(1.0f0), 1.0f0) == LogLikelihoodAccumulator(2.0f0) - @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) == + @test acclogp(LogLikelihoodAccumulator(1.0), 1.0f0) == LogLikelihoodAccumulator(2.0) @test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1) @test increment(VariableOrderAccumulator{UInt8}()) == @@ -110,6 +108,73 @@ using DynamicPPL: @test 
accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == VariableOrderAccumulator(2) end + + @testset "merge" begin + @test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) == + LogPriorAccumulator(2.0) + @test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) == + LogJacobianAccumulator(2.0) + @test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) == + LogLikelihoodAccumulator(2.0) + + @test merge( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VariableOrderAccumulator(2, Dict{VarName,Int}()), + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test merge( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(c) => 3)) + ), + ) == VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3)) + ) + end + + @testset "subset" begin + @test subset(LogPriorAccumulator(1.0), VarName[]) == LogPriorAccumulator(1.0) + @test subset(LogJacobianAccumulator(1.0), VarName[]) == + LogJacobianAccumulator(1.0) + @test subset(LogLikelihoodAccumulator(1.0), VarName[]) == + LogLikelihoodAccumulator(1.0) + + @test subset( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VarName[@varname(a), @varname(b)], + ) == VariableOrderAccumulator(1, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[@varname(a)], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}((@varname(a) => 1))) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a) => 1, + @varname(a.b.c) => 2, + @varname(a.b.c.d[1]) => 2, + @varname(b) => 3, + @varname(c[1]) => 4, + )), + ), + 
VarName[@varname(a.b), @varname(b)], + ) == VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a.b.c) => 2, @varname(a.b.c.d[1]) => 2, @varname(b) => 3 + )), + ) + end end @testset "accumulator tuples" begin @@ -118,7 +183,7 @@ using DynamicPPL: lp_f32 = LogPriorAccumulator(1.0f0) ll_f64 = LogLikelihoodAccumulator(1.0) ll_f32 = LogLikelihoodAccumulator(1.0f0) - np_i64 = VariableOrderAccumulator(1) + vo_i64 = VariableOrderAccumulator(1) @testset "constructors" begin @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) @@ -132,22 +197,22 @@ using DynamicPPL: end @testset "basic operations" begin - at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64) + at_all64 = AccumulatorTuple(lp_f64, ll_f64, vo_i64) @test at_all64[:LogPrior] == lp_f64 @test at_all64[:LogLikelihood] == ll_f64 - @test at_all64[:VariableOrder] == np_i64 + @test at_all64[:VariableOrder] == vo_i64 - @test haskey(AccumulatorTuple(np_i64), Val(:VariableOrder)) - @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior)) - @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3 + @test haskey(AccumulatorTuple(vo_i64), Val(:VariableOrder)) + @test ~haskey(AccumulatorTuple(vo_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, vo_i64)) == 3 @test keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder) - @test collect(at_all64) == [lp_f64, ll_f64, np_i64] + @test collect(at_all64) == [lp_f64, ll_f64, vo_i64] # Replace the existing LogPriorAccumulator @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 # Check that setacc!! didn't modify the original - @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64) + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, vo_i64) # Add a new accumulator type. 
@test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == AccumulatorTuple(lp_f64, ll_f64) @@ -175,6 +240,52 @@ using DynamicPPL: acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) end + + @testset "merge" begin + vo1 = VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a) => 1, @varname(b) => 1) + ) + vo2 = VariableOrderAccumulator( + 2, Dict{VarName,Int}(@varname(a) => 2, @varname(c) => 2) + ) + accs1 = AccumulatorTuple(lp_f64, ll_f64, vo1) + accs2 = AccumulatorTuple(lp_f32, vo2) + @test merge(accs1, accs2) == AccumulatorTuple( + ll_f64, + lp_f32, + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(@varname(a) => 2, @varname(b) => 1, @varname(c) => 2), + ), + ) + @test merge(AccumulatorTuple(), accs1) == accs1 + @test merge(accs1, AccumulatorTuple()) == accs1 + @test merge(accs1, accs1) == accs1 + end + + @testset "subset" begin + accs = AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, + Dict{VarName,Int}( + @varname(a.b) => 1, @varname(a.b[1]) => 2, @varname(b) => 1 + ), + ), + ) + + @test subset(accs, VarName[]) == AccumulatorTuple( + lp_f64, ll_f64, VariableOrderAccumulator(1, Dict{VarName,Int}()) + ) + @test subset(accs, VarName[@varname(a)]) == AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a.b) => 1, @varname(a.b[1]) => 2) + ), + ) + end end end From c98a67c0d09e72a337e3d5505e94895b5fa748d7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 6 Aug 2025 13:17:42 +0100 Subject: [PATCH 27/27] Minor tweak to changelog wording --- HISTORY.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index b59d8dd7f..c0e265fbf 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -3,7 +3,8 @@ ## 0.37.0 DynamicPPL 0.37 comes with a substantial reworking of its internals. 
-Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release is unlikely to affect you much. +Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release will not affect you too much (apart from the changes to `@addlogprob!`). +Any such changes will be covered separately in the Turing.jl changelog when a release is made. However, if you are a package developer or someone who uses DynamicPPL's functionality directly, you will notice a number of changes. To avoid overwhelming the reader, we begin by listing the most important, user-facing changes, before explaining the changes to the internals in more detail.