1- struct TrackedValue{T}
2- value:: T
3- end
4-
5- is_tracked_value (:: TrackedValue ) = true
6- is_tracked_value (:: Any ) = false
7-
8- check_tilde_rhs (x:: TrackedValue ) = x
9-
101"""
11- ValuesAsInModelContext
2+ ValuesAsInModelAccumulator <: AbstractAccumulator
123
13- A context that is used by [`values_as_in_model`](@ref) to obtain values
4+ An accumulator that is used by [`values_as_in_model`](@ref) to obtain values
145of the model parameters as they are in the model.
156
167This is particularly useful when working in unconstrained space, but one
@@ -19,72 +10,47 @@ wants to extract the realization of a model in a constrained space.
1910# Fields
2011$(TYPEDFIELDS)
2112"""
22- struct ValuesAsInModelContext{C <: AbstractContext } <: AbstractContext
13+ struct ValuesAsInModelAccumulator <: AbstractAccumulator
2314 " values that are extracted from the model"
2415 values:: OrderedDict
2516 " whether to extract variables on the LHS of :="
2617 include_colon_eq:: Bool
27- " child context"
28- context:: C
2918end
30- function ValuesAsInModelContext (include_colon_eq, context :: AbstractContext )
31- return ValuesAsInModelContext (OrderedDict (), include_colon_eq, context )
19+ function ValuesAsInModelAccumulator (include_colon_eq)
20+ return ValuesAsInModelAccumulator (OrderedDict (), include_colon_eq)
3221end
3322
34- NodeTrait (:: ValuesAsInModelContext ) = IsParent ()
35- childcontext (context:: ValuesAsInModelContext ) = context. context
36- function setchildcontext (context:: ValuesAsInModelContext , child)
37- return ValuesAsInModelContext (context. values, context. include_colon_eq, child)
38- end
23+ accumulator_name (:: Type{<:ValuesAsInModelAccumulator} ) = :ValuesAsInModel
3924
40- is_extracting_values (context:: ValuesAsInModelContext ) = context. include_colon_eq
41- function is_extracting_values (context:: AbstractContext )
42- return is_extracting_values (NodeTrait (context), context)
25+ function split (acc:: ValuesAsInModelAccumulator )
26+ return ValuesAsInModelAccumulator (empty (acc. values), acc. include_colon_eq)
4327end
44- is_extracting_values (:: IsParent , :: AbstractContext ) = false
45- is_extracting_values (:: IsLeaf , :: AbstractContext ) = false
46-
47- function Base. push! (context:: ValuesAsInModelContext , vn:: VarName , value)
48- return setindex! (context. values, copy (value), prefix (context, vn))
28+ function combine (acc1:: ValuesAsInModelAccumulator , acc2:: ValuesAsInModelAccumulator )
29+ if acc1. include_colon_eq != acc2. include_colon_eq
30+ msg = " Cannot combine accumulators with different include_colon_eq values."
31+ throw (ArgumentError (msg))
32+ end
33+ return ValuesAsInModelAccumulator (
34+ merge (acc1. values, acc2. values), acc1. include_colon_eq
35+ )
4936end
5037
51- function broadcast_push! (context:: ValuesAsInModelContext , vns, values)
52- return push! .((context,), vns, values)
38+ function Base. push! (acc:: ValuesAsInModelAccumulator , vn:: VarName , val)
39+ setindex! (acc. values, deepcopy (val), vn)
40+ return acc
5341end
5442
55- # This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
56- function broadcast_push! (
57- context:: ValuesAsInModelContext , vns:: AbstractVector , values:: AbstractMatrix
58- )
59- for (vn, col) in zip (vns, eachcol (values))
60- push! (context, vn, col)
61- end
43+ function is_extracting_values (vi:: AbstractVarInfo )
44+ return hasacc (vi, Val (:ValuesAsInModel )) &&
45+ getacc (vi, Val (:ValuesAsInModel )). include_colon_eq
6246end
6347
64- # `tilde_asssume`
65- function tilde_assume (context:: ValuesAsInModelContext , right, vn, vi)
66- if is_tracked_value (right)
67- value = right. value
68- else
69- value, vi = tilde_assume (childcontext (context), right, vn, vi)
70- end
71- push! (context, vn, value)
72- return value, vi
73- end
74- function tilde_assume (
75- rng:: Random.AbstractRNG , context:: ValuesAsInModelContext , sampler, right, vn, vi
76- )
77- if is_tracked_value (right)
78- value = right. value
79- else
80- value, vi = tilde_assume (rng, childcontext (context), sampler, right, vn, vi)
81- end
82- # Save the value.
83- push! (context, vn, value)
84- # Pass on.
85- return value, vi
48+ function accumulate_assume!! (acc:: ValuesAsInModelAccumulator , val, logjac, vn, right)
49+ return push! (acc, vn, val)
8650end
8751
52+ accumulate_observe!! (acc:: ValuesAsInModelAccumulator , right, left, vn) = acc
53+
8854"""
8955 values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
9056
@@ -103,7 +69,7 @@ space at the cost of additional model evaluations.
10369- `model::Model`: model to extract realizations from.
10470- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
10571- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
106- - `context::AbstractContext`: base context to use for the extraction. Defaults
72+ - `context::AbstractContext`: evaluation context to use in the extraction. Defaults
10773 to `DynamicPPL.DefaultContext()`.
10874
10975# Examples
@@ -164,7 +130,8 @@ function values_as_in_model(
164130 varinfo:: AbstractVarInfo ,
165131 context:: AbstractContext = DefaultContext (),
166132)
167- context = ValuesAsInModelContext (include_colon_eq, context)
168- evaluate!! (model, varinfo, context)
169- return context. values
133+ accs = getaccs (varinfo)
134+ varinfo = setaccs!! (deepcopy (varinfo), (ValuesAsInModelAccumulator (include_colon_eq),))
135+ varinfo = last (evaluate!! (model, varinfo, context))
136+ return getacc (varinfo, Val (:ValuesAsInModel )). values
170137end
0 commit comments