diff --git a/HISTORY.md b/HISTORY.md index ddbe67842..d67afcbfe 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,8 +2,66 @@ ## 0.38.0 -The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. -Their behaviour is otherwise identical. +**Breaking changes** + +### Introduction of `InitContext` + +DynamicPPL 0.38 introduces a new evaluation context, `InitContext`. +It is used to generate fresh values for random variables in a model. + +Evaluation contexts are stored inside a `DynamicPPL.Model` object, and control what happens with tilde-statements when a model is run. +The two major leaf (basic) contexts are `DefaultContext` and, now, `InitContext`. +`DefaultContext` is the default context, and it simply uses the values that are already stored in the `VarInfo` object passed to the model evaluation function. +On the other hand, `InitContext` ignores values in the VarInfo object and inserts new values obtained from a specified source. +(It follows also that the VarInfo being used may be empty, which means that `InitContext` is now also the way to obtain a fresh VarInfo for a model.) + +DynamicPPL 0.38 provides three flavours of _initialisation strategies_, which are specified as the second argument to `InitContext`: + + - `InitContext(rng, InitFromPrior())`: New values are sampled from the prior distribution (on the right-hand side of the tilde). + - `InitContext(rng, InitFromUniform(a, b))`: New values are sampled uniformly from the interval `[a, b]`, and then invlinked to the support of the distribution on the right-hand side of the tilde. + - `InitContext(rng, InitFromParams(p, fallback))`: New values are obtained by indexing into the `p` object, which can be a `NamedTuple` or `Dict{<:VarName}`. If a variable is not found in `p`, then the `fallback` strategy is used, which is simply another of these strategies. In particular, `InitFromParams` enables the case where different variables are to be initialised from different sources. + +(It is possible to define your own initialisation strategy; users who wish to do so are referred to the DynamicPPL API documentation and source code.) + +**The main impact on the upcoming Turing.jl release** is that, instead of providing initial values for sampling, the user will be expected to provide an initialisation strategy instead. +This is a more flexible approach, and not only solves a number of pre-existing issues with initialisation of Turing models, but also improves the clarity of user code. +In particular: + + - When providing a set of fixed parameters (i.e. `InitFromParams(p)`), `p` must now either be a NamedTuple or a Dict. Previously Vectors were allowed, which is error-prone because the ordering of variables in a VarInfo is not obvious. + - The parameters in `p` must now always be provided in unlinked space (i.e., in the space of the distribution on the right-hand side of the tilde). Previously, whether a parameter was expected to be in linked or unlinked space depended on whether the VarInfo was linked or not, which was confusing. + +### Removal of `SamplingContext` + +For developers working on DynamicPPL, `InitContext` now completely replaces what used to be `SamplingContext`, `SampleFromPrior`, and `SampleFromUniform`. +Evaluating a model with `SamplingContext(SampleFromPrior())` (e.g. with `DynamicPPL.evaluate_and_sample!!(model, VarInfo(), SampleFromPrior())` has a direct one-to-one replacement in `DynamicPPL.init!!(model, VarInfo(), InitFromPrior())`. +Please see the docstring of `init!!` for more details. +Likewise `SampleFromUniform()` can be replaced with `InitFromUniform()`. +`InitFromParams()` provides new functionality which was previously implemented in the roundabout way of manipulating the VarInfo (e.g. using `unflatten`, or even more hackily by directly modifying values in the VarInfo), and then evaluating using `DefaultContext`. + +The main change that this is likely to create is for those who are implementing samplers or inference algorithms. +The exact way in which this happens will be detailed in the Turing.jl changelog when a new release is made. +Broadly speaking, though, `SamplingContext(MySampler())` will be removed so if your sampler needs custom behaviour with the tilde-pipeline you will likely have to define your own context. + +### Simplification of the tilde-pipeline + +There are now only two functions in the tilde-pipeline that need to be overloaded to change the behaviour of tilde-statements, namely, `tilde_assume!!` and `tilde_observe!!`. +Other functions such as `tilde_assume` and `assume` (and their `observe` counterparts) have been removed. + +Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other). +The separation of these functions was primarily implemented to avoid performing extra work where unneeded (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details). + +**Other changes** + +### Reimplementation of functions using `InitContext` + +A number of functions have been reimplemented and unified with the help of `InitContext`. +In particular, this release brings substantial performance improvements for `returned` and `predict`. +Their APIs are the same. + +### Upstreaming of VarName functionality + +The implementation of the `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical, and they are still accessible from the DynamicPPL module (though still not exported). ## 0.37.3 diff --git a/docs/src/api.md b/docs/src/api.md index d1dddb560..e5c483bca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa A core component of DynamicPPL is the [`@model`](@ref) macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements. -These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities. +These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities. ```@docs @model @@ -344,6 +344,13 @@ Base.empty! SimpleVarInfo ``` +### Tilde-pipeline + +```@docs +tilde_assume!! +tilde_observe!! +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -447,12 +454,12 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs -SamplingContext DefaultContext PrefixContext ConditionContext @@ -486,15 +493,7 @@ DynamicPPL.init ### Samplers -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. - -```@docs -SampleFromPrior -SampleFromUniform -``` - -Additionally, a generic sampler for inference is implemented. +In DynamicPPL a generic sampler for inference is implemented. ```@docs Sampler @@ -520,9 +519,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### [Model-Internal Functions](@id model_internal) - -```@docs -tilde_assume -``` diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index d592e76b3..0088f8908 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - # Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6a01884a9..edf44439e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -96,18 +96,16 @@ export AbstractVarInfo, values_as_in_model, # Samplers Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, PrefixContext, ConditionContext, - assume, - tilde_assume, + # Tilde pipeline + tilde_assume!!, + tilde_observe!!, # Initialisation InitContext, AbstractInitStrategy, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..a8f2d57e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,42 +1,37 @@ -# assume """ - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo + ) + +Handle assumed variables, i.e. anything which is not observed (see +[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the +sampled value and updated `vi`. + +`vn` is the VarName on the left-hand side of the tilde statement. """ -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -end - -function tilde_assume(context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end -function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) +function tilde_assume!!( + context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + return tilde_assume!!(childcontext(context), right, vn, vi) end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) +function tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi end - -function tilde_assume(context::PrefixContext, right, vn, vi) +function tilde_assume!!( + context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) # Note that we can't use something like this here: # new_vn = prefix(context, vn) - # return tilde_assume(childcontext(context), right, new_vn, vi) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) # This is because `prefix` applies _all_ prefixes in a given context to a # variable name. Thus, if we had two levels of nested prefixes e.g. # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the @@ -44,50 +39,64 @@ function tilde_assume(context::PrefixContext, right, vn, vi) # would apply the prefix `b._`, resulting in `b.a.b._`. # This is why we need a special function, `prefix_and_strip_contexts`. new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(new_context, right, new_vn, vi) + return tilde_assume!!(new_context, right, new_vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi +""" + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo + ) + +Evaluate the submodel with the given context. +""" +function tilde_assume!!( + context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo ) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) + return _evaluate!!(right, vi, context, vn) end """ - tilde_assume!!(context, right, vn, vi) + tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName, Nothing}, + vi::AbstractVarInfo + ) -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value and updated `vi`. +This function handles observed variables, which may be: -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_assume!!(context, right, vn, vi) - return if right isa DynamicPPL.Submodel - _evaluate!!(right, vi, context, vn) - else - tilde_assume(context, right, vn, vi) - end -end +- literals on the left-hand side, e.g., `3.0 ~ Normal()` +- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` +- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. -# observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) +The relevant log-probability associated with the observation is computed and accumulated in +the VarInfo object `vi` (except for fixed variables, which do not contribute to the +log-probability). -Handle observed constants with a `context` associated with a sampler. +`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the +left-hand side, or `nothing` if the left-hand side is a literal value. -Falls back to `tilde_observe!!(context.context, right, left, vi)`. +Observations of submodels are not yet supported in DynamicPPL. """ -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - -function tilde_observe!!(context::AbstractContext, right, left, vn, vi) +function tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(childcontext(context), right, left, vn, vi) end - -# `PrefixContext` -function tilde_observe!!(context::PrefixContext, right, left, vn, vi) +function tilde_observe!!( + context::PrefixContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal # value. For the need for prefix_and_strip_contexts rather than just prefix, see the # comment in `tilde_assume!!`. @@ -98,74 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi) end return tilde_observe!!(new_context, right, left, new_vn, vi) end - -""" - tilde_observe!!(context, right, left, vn, vi) - -Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name -and indices; if needed, these can be accessed through this function, though. -""" -function tilde_observe!!(::DefaultContext, right, left, vn, vi) - right isa DynamicPPL.Submodel && - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +function tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end - -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, +function tilde_observe!!( + ::AbstractContext, + ::DynamicPPL.Submodel, + left, + vn::Union{VarName,Nothing}, + ::AbstractVarInfo, ) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. - unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! - vi = BangBang.setindex!!(vi, f(r), vn) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, logjac, vn, dist) - return r, vi + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) end diff --git a/src/contexts.jl b/src/contexts.jl index cd9876768..439da47e5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -47,7 +47,7 @@ effectively updating the child context. ```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts -""" - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], - ) - -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. - -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. -""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - """ struct DefaultContext <: AbstractContext end @@ -252,7 +185,7 @@ PrefixContexts removed. NOTE: This does _not_ modify any variables in any `ConditionContext` and `FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume`, which is lower in the tilde-pipeline +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline than `contextual_isassumption` and `contextual_isfixed` (the functions which actually use the `ConditionContext` and `FixedContext` values). Thus, by this time, any `ConditionContext`s and `FixedContext`s present have already served diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 636847117..4baca1b57 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -154,7 +154,7 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon end NodeTrait(::InitContext) = IsLeaf() -function tilde_assume( +function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) @@ -191,6 +191,12 @@ function tilde_assume( return x, vi end -function tilde_observe!!(::InitContext, right, left, vn, vi) +function tilde_observe!!( + ::InitContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..2ec8b15a2 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -485,7 +485,7 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) + new_model = DynamicPPL.contextualize(model, InitContext(rng)) results = map(1:num_evals) do _ check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end diff --git a/src/sampler.jl b/src/sampler.jl index 98b50ba55..8b49f6c3b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,34 +1,3 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. - -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - # TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? # (Selector has been removed). """ @@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler alg::T end -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform() - _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) - return new_vi, nothing -end - """ default_varinfo(rng, model, sampler) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 27365e4dc..f430755e7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -466,25 +466,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -# Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. - f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, logjac, vn, dist) - return value, vi -end - function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end diff --git a/src/transforming.jl b/src/transforming.jl index 56f861cff..589dca031 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -12,8 +12,11 @@ how to do the transformation, used by e.g. `SimpleVarInfo`. struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() -function tilde_assume( - ::DynamicTransformationContext{isinverse}, right, vn, vi +function tilde_assume!!( + ::DynamicTransformationContext{isinverse}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] @@ -31,7 +34,13 @@ function tilde_assume( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) +function tilde_observe!!( + ::DynamicTransformationContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/utils.jl b/src/utils.jl index c7d1e089f..a4c5f4a1b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) -end - ####################### # Convenience methods # ####################### diff --git a/test/Project.toml b/test/Project.toml index 537214464..3e7ffcaea 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -39,7 +38,6 @@ DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" -EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9 - 0.10.6" LogDensityProblems = "2" diff --git a/test/ad.jl b/test/ad.jl index 0e5d8d7cf..23e676ee7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,49 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - vi = DynamicPPL.link!!(VarInfo(model), model) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction( - sampling_model, getlogjoint_internal, vi; adtype=AutoReverseDiff(; compile=true) - ) - x = ldf.varinfo[:] - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/contexts.jl b/test/contexts.jl index 1a6279bf4..2687c4336 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -24,8 +24,6 @@ using DynamicPPL: using LinearAlgebra: I using Random: Xoshiro -using EnzymeCore - # TODO: Should we maybe put this in DPPL itself? function Base.iterate(context::AbstractContext) if NodeTrait(context) isa IsLeaf @@ -150,11 +148,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vn = @varname(x[1]) ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) + ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) + ctx4 = FixedContext(Dict(), ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -203,22 +201,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 5bf741ff3..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -149,7 +149,7 @@ model = demo_missing_in_multivariate([1.0, missing]) # Have to run this check_model call with an empty varinfo, because actually # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, SamplingContext()) + model = contextualize(model, InitContext()) @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 692f53911..b34424a1c 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -64,6 +64,12 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) + # Check that the inferred varinfo is indeed suitable for evaluation + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f_eval, argtypes_eval) + # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed diff --git a/test/lkj.jl b/test/lkj.jl index d581cd21b..5c5603aba 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -16,20 +16,15 @@ end # Same for both distributions target_mean = vec(Matrix{Float64}(I, 2, 2)) +n_samples = 1000 _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) + for init_strategy in [InitFromPrior(), InitFromUniform()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -37,21 +32,10 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) - # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor + for init_strategy in [InitFromPrior(), InitFromUniform()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) pd_from_triangular(M, uplo) diff --git a/test/sampler.jl b/test/sampler.jl index c812de938..5380ad17e 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -113,60 +113,6 @@ end end - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # will be drawn from U[-2, 2] and its mean should be 0. - @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling abstract type OnlyInitAlg end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 0421c89e2..522730566 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -77,7 +76,7 @@ @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadsafe!!(model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -104,13 +103,12 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end