diff --git a/HISTORY.md b/HISTORY.md index 24c0df3d0..f181897f7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,26 @@ # DynamicPPL Changelog +## 0.39.0 + +### Breaking changes + +#### Parent and leaf contexts + +The `DynamicPPL.NodeTrait` function has been removed. +Instead of implementing this, parent contexts should subtype `DynamicPPL.AbstractParentContext`. +This is an abstract type which requires you to overload two functions, `DynamicPPL.childcontext` and `DynamicPPL.setchildcontext`. + +There should generally be few reasons to define your own parent contexts (the only one we are aware of, outside of DynamicPPL itself, is `Turing.Inference.GibbsContext`), so this change should not really affect users. + +Leaf contexts require no changes, apart from a removal of the `NodeTrait` function. + +`ConditionContext` and `PrefixContext` are no longer exported. +You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead. + +#### Miscellaneous + +Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/Project.toml b/Project.toml index 23f5eec5b..1b5e52492 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.9" +version = "0.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 0d4e9a654..55ca81da0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -23,7 +23,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.38" +DynamicPPL = "0.39" Enzyme = "0.13" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" diff --git a/docs/Project.toml b/docs/Project.toml index 03a3ff0a0..10a4a5c8a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -21,7 +21,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" diff --git a/docs/src/api.md b/docs/src/api.md index bbe39fb73..63dafdfca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -352,13 +352,6 @@ Base.empty! SimpleVarInfo ``` -### Tilde-pipeline - -```@docs -tilde_assume!! -tilde_observe!! -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -463,15 +456,48 @@ By default, it does not perform any actual sampling: it only evaluates the model If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. -Contexts are subtypes of `AbstractPPL.AbstractContext`. + +All contexts are subtypes of `AbstractPPL.AbstractContext`. + +Contexts are split into two kinds: + +**Leaf contexts**: These are the most important contexts as they ultimately decide how model evaluation proceeds. +For example, `DefaultContext` evaluates the model using values stored inside a VarInfo's metadata, whereas `InitContext` obtains new values either by sampling or from a known set of parameters. +DynamicPPL has more leaf contexts which are used for internal purposes, but these are the two that are exported. ```@docs DefaultContext -PrefixContext -ConditionContext InitContext ``` +To implement a leaf context, you need to subtype `AbstractPPL.AbstractContext` and implement the `tilde_assume!!` and `tilde_observe!!` methods for your context. + +```@docs +tilde_assume!! +tilde_observe!! +``` + +**Parent contexts**: These essentially act as 'modifiers' for leaf contexts. +For example, `PrefixContext` adds a prefix to all variable names during evaluation, while `ConditionContext` marks certain variables as observed. + +To implement a parent context, you have to subtype `DynamicPPL.AbstractParentContext`, and implement the `childcontext` and `setchildcontext` methods. +If needed, you can also implement `tilde_assume!!` and `tilde_observe!!` for your context. +This is optional; the default implementation is to simply delegate to the child context. + +```@docs +AbstractParentContext +childcontext +setchildcontext +``` + +Since contexts form a tree structure, these functions are automatically defined for manipulating context stacks. +They are mainly useful for modifying the fundamental behaviour (i.e. the leaf context), without affecting any of the modifiers (i.e. parent contexts). + +```@docs +leafcontext +setleafcontext +``` + ### VarInfo initialisation The function `init!!` is used to initialise, or overwrite, values in a VarInfo. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..c43bd89d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -94,16 +94,21 @@ export AbstractVarInfo, values_as_in_model, # LogDensityFunction LogDensityFunction, - # Contexts + # Leaf contexts + AbstractContext, contextualize, DefaultContext, - PrefixContext, - ConditionContext, + InitContext, + # Parent contexts + AbstractParentContext, + childcontext, + setchildcontext, + leafcontext, + setleafcontext, # Tilde pipeline tilde_assume!!, tilde_observe!!, # Initialisation - InitContext, AbstractInitStrategy, InitFromPrior, InitFromUniform, diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..46c5b8855 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,48 +1,32 @@ """ - NodeTrait(context) - NodeTrait(f, context) + AbstractParentContext -Specifies the role of `context` in the context-tree. +An abstract context that has a child context. -The officially supported traits are: -- `IsLeaf`: `context` does not have any decendants. -- `IsParent`: `context` has a child context to which we often defer. - Expects the following methods to be implemented: - - [`childcontext`](@ref) - - [`setchildcontext`](@ref) -""" -abstract type NodeTrait end -NodeTrait(_, context) = NodeTrait(context) - -""" - IsLeaf - -Specifies that the context is a leaf in the context-tree. -""" -struct IsLeaf <: NodeTrait end -""" - IsParent +Subtypes of `AbstractParentContext` must implement the following interface: -Specifies that the context is a parent in the context-tree. +- `DynamicPPL.childcontext(context::AbstractParentContext)`: Return the child context. +- `DynamicPPL.setchildcontext(parent::AbstractParentContext, child::AbstractContext)`: Reconstruct + `parent` but now using `child` as its child context. """ -struct IsParent <: NodeTrait end +abstract type AbstractParentContext <: AbstractContext end """ - childcontext(context) + childcontext(context::AbstractParentContext) Return the descendant context of `context`. """ childcontext """ - setchildcontext(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractParentContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. # Examples ```jldoctest -julia> using DynamicPPL: DynamicTransformationContext +julia> using DynamicPPL: DynamicTransformationContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); @@ -60,12 +44,11 @@ setchildcontext """ leafcontext(context::AbstractContext) -Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +Return the leaf of `context`, i.e. the first descendant context that is not an +`AbstractParentContext`. """ -leafcontext(context::AbstractContext) = - leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context::AbstractContext) = context -leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = context +leafcontext(context::AbstractParentContext) = leafcontext(childcontext(context)) """ setleafcontext(left::AbstractContext, right::AbstractContext) @@ -80,12 +63,10 @@ original leaf context of `left`. ```jldoctest julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext -julia> struct ParentContext{C} <: AbstractContext +julia> struct ParentContext{C} <: AbstractParentContext context::C end -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - julia> DynamicPPL.childcontext(context::ParentContext) = context.context julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) @@ -104,21 +85,10 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left::AbstractContext, right::AbstractContext) - return setleafcontext( - NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right - ) -end -function setleafcontext( - ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext -) +function setleafcontext(left::AbstractParentContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) - return setchildcontext(left, setleafcontext(childcontext(left), right)) -end -setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right -setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( @@ -138,10 +108,15 @@ This function should return a tuple `(x, vi)`, where `x` is the sampled value (w must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractParentContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) return tilde_assume!!(childcontext(context), right, vn, vi) end +function tilde_assume!!( + context::AbstractContext, ::Distribution, ::VarName, ::AbstractVarInfo +) + return error("tilde_assume!! not implemented for context of type $(typeof(context))") +end """ DynamicPPL.tilde_observe!!( @@ -171,7 +146,7 @@ This function should return a tuple `(left, vi)`, where `left` is the same as th `vi` is the updated VarInfo. """ function tilde_observe!!( - context::AbstractContext, + context::AbstractParentContext, right::Distribution, left, vn::Union{VarName,Nothing}, @@ -179,3 +154,12 @@ function tilde_observe!!( ) return tilde_observe!!(childcontext(context), right, left, vn, vi) end +function tilde_observe!!( + context::AbstractContext, + ::Distribution, + ::Any, + ::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + return error("tilde_observe!! not implemented for context of type $(typeof(context))") +end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..7a34db5cb 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -11,7 +11,7 @@ when there are varnames that cannot be represented as symbols, e.g. """ struct ConditionContext{ Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext +} <: AbstractParentContext values::Values context::Ctx end @@ -41,9 +41,10 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +function setchildcontext(parent::ConditionContext, child::AbstractContext) + return ConditionContext(parent.values, child) +end """ hasconditioned(context::AbstractContext, vn::VarName) @@ -76,11 +77,8 @@ Return `true` if `vn` is found in `context` or any of its descendants. This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) +hasconditioned_nested(context::AbstractContext, vn) = hasconditioned(context, vn) +function hasconditioned_nested(context::AbstractParentContext, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) @@ -96,15 +94,12 @@ This is contrast to [`getconditioned`](@ref) which only returns the value `vn` i not recursively looking into its descendants. """ function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end -function getconditioned_nested(::IsParent, context, vn) +function getconditioned_nested(context::AbstractParentContext, vn) return if hasconditioned(context, vn) getconditioned(context, vn) else @@ -113,7 +108,7 @@ function getconditioned_nested(::IsParent, context, vn) end """ - decondition(context::AbstractContext, syms...) + decondition_context(context::AbstractContext, syms...) Return `context` but with `syms` no longer conditioned on. @@ -121,13 +116,10 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) +decondition_context(context::AbstractContext, args...) = context +function decondition_context(context::AbstractParentContext, args...) return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end function decondition_context(context::ConditionContext) return decondition_context(childcontext(context)) end @@ -160,11 +152,8 @@ Return `NamedTuple` of values that are conditioned on under context`. Note that this will recursively traverse the context stack and return a merged version of the condition values. """ -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) +conditioned(::AbstractContext) = NamedTuple() +conditioned(context::AbstractParentContext) = conditioned(childcontext(context)) function conditioned(context::ConditionContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -176,7 +165,7 @@ function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractParentContext values::Values context::Ctx end @@ -197,16 +186,17 @@ function Base.show(io::IO, context::FixedContext) return print(io, "FixedContext($(context.values), $(childcontext(context)))") end -NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +function setchildcontext(parent::FixedContext, child::AbstractContext) + return FixedContext(parent.values, child) +end """ hasfixed(context::AbstractContext, vn::VarName) Return `true` if a fixed value for `vn` is found in `context`. """ -hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(::AbstractContext, ::VarName) = false hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) @@ -230,11 +220,8 @@ Return `true` if a fixed value for `vn` is found in `context` or any of its desc This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) +hasfixed_nested(context::AbstractContext, vn) = hasfixed(context, vn) +function hasfixed_nested(context::AbstractParentContext, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) @@ -250,15 +237,12 @@ This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `con not recursively looking into its descendants. """ function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) return getfixed_nested(collapse_prefix_stack(context), vn) end -function getfixed_nested(::IsParent, context, vn) +function getfixed_nested(context::AbstractParentContext, vn) return if hasfixed(context, vn) getfixed(context, vn) else @@ -283,7 +267,7 @@ end function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) return fix(DefaultContext(), values) end -fix(context::AbstractContext, values::NamedTuple{()}) = context +fix(context::AbstractContext, ::NamedTuple{()}) = context function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) return FixedContext(values, context) end @@ -306,13 +290,10 @@ Note that this recursively traverses contexts, unfixing all along the way. See also: [`fix`](@ref) """ -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) +unfix(context::AbstractContext, args...) = context +function unfix(context::AbstractParentContext, args...) return setchildcontext(context, unfix(childcontext(context), args...)) end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end function unfix(context::FixedContext) return unfix(childcontext(context)) end @@ -341,9 +322,8 @@ Return the values that are fixed under `context`. Note that this will recursively traverse the context stack and return a merged version of the fix values. """ -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) +fixed(::AbstractContext) = NamedTuple() +fixed(context::AbstractParentContext) = fixed(childcontext(context)) function fixed(context::FixedContext) # Note the order of arguments to `merge`. The behavior of the rest of DPPL # is that the outermost `context` takes precendence, hence when resolving @@ -374,7 +354,7 @@ topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_co which explains this in much more detail. ```jldoctest -julia> using DynamicPPL: collapse_prefix_stack +julia> using DynamicPPL: collapse_prefix_stack, PrefixContext, ConditionContext julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); @@ -403,11 +383,8 @@ function collapse_prefix_stack(context::PrefixContext) # depth of the context stack. return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) +collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractParentContext) new_child_context = collapse_prefix_stack(childcontext(context)) return setchildcontext(context, new_child_context) end @@ -448,19 +425,10 @@ function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) return FixedContext(prefixed_vn_dict, prefixed_child_ctx) end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractContext, ::VarName) return context end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) +function prefix_cond_and_fixed_variables(context::AbstractParentContext, prefix::VarName) return setchildcontext( context, prefix_cond_and_fixed_variables(childcontext(context), prefix) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..3cafe39f1 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -17,7 +17,6 @@ with `DefaultContext` means 'calculating the log-probability associated with the in the `AbstractVarInfo`'. """ struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..83507353f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -150,7 +150,6 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon return InitContext(Random.default_rng(), strategy) end end -NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 24615e683..45307874a 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -13,7 +13,7 @@ unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractParentContext vn_prefix::Tvn context::C end @@ -23,7 +23,6 @@ function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} end PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) -NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context function setchildcontext(ctx::PrefixContext, child::AbstractContext) return PrefixContext(ctx.vn_prefix, child) @@ -37,11 +36,8 @@ Apply the prefixes in the context `ctx` to the variable name `vn`. function prefix(ctx::PrefixContext, vn::VarName) return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) +prefix(::AbstractContext, vn::VarName) = vn +function prefix(ctx::AbstractParentContext, vn::VarName) return prefix(childcontext(ctx), vn) end @@ -72,11 +68,8 @@ function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) ) return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) +prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(ctx::AbstractParentContext, vn::VarName) vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) return vn, setchildcontext(ctx, new_ctx) end diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..c2eee2863 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -10,7 +10,6 @@ Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the how to do the transformation, used by e.g. `SimpleVarInfo`. """ struct DynamicTransformationContext{isinverse} <: AbstractContext end -NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, diff --git a/src/model.jl b/src/model.jl index edb042ba9..94fcd9fd4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -427,7 +427,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned, contextualize, PrefixContext, ConditionContext julia> @model function demo() m ~ Normal() @@ -770,7 +770,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed, contextualize, PrefixContext julia> @model function demo() m ~ Normal() @@ -1107,11 +1107,6 @@ function predict end Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. - returned(model::Model, values, keys) - -Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. -This method is deprecated; use the NamedTuple or AbstractDict version instead. - # Example ```jldoctest julia> using DynamicPPL, Distributions @@ -1141,6 +1136,3 @@ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarN # We can't use new_model() because that overwrites it with an InitContext of its own. return first(evaluate!!(new_model, vi)) end -Base.@deprecate returned(model::Model, values, keys) returned( - model, NamedTuple{keys}(values) -) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index aae2e4ec6..c48d2ddfd 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -4,11 +4,10 @@ # Utilities for testing contexts. # Dummy context to test nested behaviors. -struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext +struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractParentContext context::C end TestParentContext() = TestParentContext(DefaultContext()) -DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::TestParentContext) = context.context DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child) function Base.show(io::IO, c::TestParentContext) @@ -25,19 +24,13 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. """ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - node_trait = DynamicPPL.NodeTrait(context) - if node_trait isa DynamicPPL.IsLeaf - test_leaf_context(context, model) - elseif node_trait isa DynamicPPL.IsParent - test_parent_context(context, model) - else - error("Invalid NodeTrait: $node_trait") - end + return test_leaf_context(context, model) +end +function test_context(context::DynamicPPL.AbstractParentContext, model::DynamicPPL.Model) + return test_parent_context(context, model) end function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # Note that for a leaf context we can't assume that it will work with an # empty VarInfo. (For example, DefaultContext will error with empty # varinfos.) Thus we only test evaluation with VarInfos that are already @@ -57,8 +50,6 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP end function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..ae7332a43 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -6,10 +6,9 @@ using DynamicPPL: childcontext, setchildcontext, AbstractContext, - NodeTrait, - IsLeaf, - IsParent, + AbstractParentContext, contextual_isassumption, + PrefixContext, FixedContext, ConditionContext, decondition_context, @@ -25,22 +24,21 @@ using LinearAlgebra: I using Random: Xoshiro # TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractParentContext) + return context, childcontext(context) +end function Base.iterate(context::AbstractContext) - if NodeTrait(context) isa IsLeaf - return nothing - end - - return context, context + return context, nothing end -function Base.iterate(_::AbstractContext, context::AbstractContext) - return _iterate(NodeTrait(context), context) +function Base.iterate(::AbstractContext, state::AbstractParentContext) + return state, childcontext(state) end -_iterate(::IsLeaf, context) = nothing -function _iterate(::IsParent, context) - child = childcontext(context) - return child, child +function Base.iterate(::AbstractContext, state::AbstractContext) + return state, nothing +end +function Base.iterate(::AbstractContext, state::Nothing) + return nothing end - Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @@ -347,11 +345,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "collapse_prefix_stack" begin # Utility function to make sure that there are no PrefixContexts in # the context stack. - function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) + has_no_prefixcontexts(::PrefixContext) = false + function has_no_prefixcontexts(ctx::AbstractParentContext) + return has_no_prefixcontexts(childcontext(ctx)) end + has_no_prefixcontexts(::AbstractContext) = true # Prefix -> Condition c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2)))