@@ -309,20 +309,40 @@ function prefix(model::Model, ::Val{x}) where {x}
309309 return contextualize (model, PrefixContext {Symbol(x)} (model. context))
310310end
311311
312- struct ConditionContext{Values,Ctx<: AbstractContext } <: AbstractContext
312+ """
313+
314+ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
315+
316+ Model context that contains values that are to be conditioned on. The values
317+ can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
318+ an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
319+ @varname(b) => 2)`). The former is more performant, but the latter must be used
320+ when there are varnames that cannot be represented as symbols, e.g.
321+ `@varname(x[1])`.
322+ """
323+ struct ConditionContext{
324+ Values<: Union{NamedTuple,AbstractDict{<:VarName}} ,Ctx<: AbstractContext
325+ } <: AbstractContext
313326 values:: Values
314327 context:: Ctx
315328end
316329
317330const NamedConditionContext{Names} = ConditionContext{<: NamedTuple{Names} }
318331const DictConditionContext = ConditionContext{<: AbstractDict }
319332
320- ConditionContext (values) = ConditionContext (values, DefaultContext ())
321-
322- # Try to avoid nested `ConditionContext`.
333+ # Use DefaultContext as the default base context
334+ function ConditionContext (values:: Union{NamedTuple,AbstractDict} )
335+ return ConditionContext (values, DefaultContext ())
336+ end
337+ # Optimisation when there are no values to condition on
338+ ConditionContext (:: NamedTuple{()} , context:: AbstractContext ) = context
339+ # Collapse consecutive levels of `ConditionContext`. Note that this overrides
340+ # values inside the child context, thus giving precedence to the outermost
341+ # `ConditionContext`.
323342function ConditionContext (values:: NamedTuple , context:: NamedConditionContext )
324- # Note that this potentially overrides values from `context`, thus giving
325- # precedence to the outmost `ConditionContext`.
343+ return ConditionContext (merge (context. values, values), childcontext (context))
344+ end
345+ function ConditionContext (values:: AbstractDict{<:VarName} , context:: DictConditionContext )
326346 return ConditionContext (merge (context. values, values), childcontext (context))
327347end
328348
@@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn)
399419 end
400420end
401421
402- """
403- condition([context::AbstractContext,] values::NamedTuple)
404- condition([context::AbstractContext]; values...)
405-
406- Return `ConditionContext` with `values` and `context` if `values` is non-empty,
407- otherwise return `context` which is [`DefaultContext`](@ref) by default.
408-
409- See also: [`decondition`](@ref)
410- """
411- AbstractPPL. condition (; values... ) = condition (NamedTuple (values))
412- AbstractPPL. condition (values:: NamedTuple ) = condition (DefaultContext (), values)
413- function AbstractPPL. condition (value:: Pair{<:VarName} , values:: Pair{<:VarName} ...)
414- return condition ((value, values... ))
415- end
416- function AbstractPPL. condition (values:: NTuple{<:Any,<:Pair{<:VarName}} )
417- return condition (DefaultContext (), values)
418- end
419- AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
420- function AbstractPPL. condition (
421- context:: AbstractContext , values:: Union{AbstractDict,NamedTuple}
422- )
423- return ConditionContext (values, context)
424- end
425- function AbstractPPL. condition (context:: AbstractContext ; values... )
426- return condition (context, NamedTuple (values))
427- end
428- function AbstractPPL. condition (
429- context:: AbstractContext , value:: Pair{<:VarName} , values:: Pair{<:VarName} ...
430- )
431- return condition (context, (value, values... ))
432- end
433- function AbstractPPL. condition (
434- context:: AbstractContext , values:: NTuple{<:Any,Pair{<:VarName}}
435- )
436- return condition (context, Dict (values))
437- end
438-
439422"""
440423 decondition(context::AbstractContext, syms...)
441424
@@ -445,41 +428,34 @@ Note that this recursively traverses contexts, deconditioning all along the way.
445428
446429See also: [`condition`](@ref)
447430"""
448- AbstractPPL . decondition (:: IsLeaf , context, args... ) = context
449- function AbstractPPL . decondition (:: IsParent , context, args... )
450- return setchildcontext (context, decondition (childcontext (context), args... ))
431+ decondition_context (:: IsLeaf , context, args... ) = context
432+ function decondition_context (:: IsParent , context, args... )
433+ return setchildcontext (context, decondition_context (childcontext (context), args... ))
451434end
452- function AbstractPPL . decondition (context, args... )
453- return decondition (NodeTrait (context), context, args... )
435+ function decondition_context (context, args... )
436+ return decondition_context (NodeTrait (context), context, args... )
454437end
455- function AbstractPPL. decondition (context:: ConditionContext )
456- return decondition (childcontext (context))
457- end
458- function AbstractPPL. decondition (context:: ConditionContext , sym)
459- return condition (
460- decondition (childcontext (context), sym), BangBang. delete!! (context. values, sym)
461- )
438+ function decondition_context (context:: ConditionContext )
439+ return decondition_context (childcontext (context))
462440end
463- function AbstractPPL. decondition (context:: ConditionContext , sym, syms... )
464- return decondition (
465- condition (
466- decondition (childcontext (context), syms... ),
467- BangBang. delete!! (context. values, sym),
468- ),
469- syms... ,
470- )
471- end
472-
473- function AbstractPPL. decondition (
474- context:: NamedConditionContext , vn:: VarName{sym}
475- ) where {sym}
476- return condition (
477- decondition (childcontext (context), vn), BangBang. delete!! (context. values, sym)
478- )
441+ function decondition_context (context:: ConditionContext , sym, syms... )
442+ new_values = deepcopy (context. values)
443+ for s in (sym, syms... )
444+ new_values = BangBang. delete!! (new_values, s)
445+ end
446+ return if length (new_values) == 0
447+ # No more values left, can unwrap
448+ decondition_context (childcontext (context), syms... )
449+ else
450+ ConditionContext (
451+ new_values, decondition_context (childcontext (context), sym, syms... )
452+ )
453+ end
479454end
480- function AbstractPPL. decondition (context:: ConditionContext , vn:: VarName )
481- return condition (
482- decondition (childcontext (context), vn), BangBang. delete!! (context. values, vn)
455+ function decondition_context (context:: NamedConditionContext , vn:: VarName{sym} ) where {sym}
456+ return ConditionContext (
457+ BangBang. delete!! (context. values, sym),
458+ decondition_context (childcontext (context), vn),
483459 )
484460end
485461
0 commit comments