@@ -124,85 +124,119 @@ end
124124# #####################
125125# Default Transition #
126126# #####################
127- # Default
128- getstats (t) = nothing
127+ getstats (:: Any ) = NamedTuple ()
129128
129+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
130+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130131abstract type AbstractTransition end
131132
132- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
133+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
133134 θ:: T
134- lp:: F # TODO : merge `lp` with `stat`
135- stat:: S
136- end
135+ logprior:: F
136+ loglikelihood:: F
137+ stat:: N
138+
139+ """
140+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141+
142+ Construct a new `Turing.Inference.Transition` object using the outputs of a
143+ sampler step.
144+
145+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
146+ already been set_. However, the accumulators (e.g. logp) may in general
147+ have junk contents. The role of this method is to re-evaluate `model` and
148+ thus set the accumulators to the correct values.
149+
150+ `sampler_transition` is the transition object returned by the sampler
151+ itself and is only used to extract statistics of interest.
152+ """
153+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
154+ vi = DynamicPPL. setaccs!! (
155+ vi,
156+ (
157+ DynamicPPL. ValuesAsInModelAccumulator (true ),
158+ DynamicPPL. LogPriorAccumulator (),
159+ DynamicPPL. LogLikelihoodAccumulator (),
160+ ),
161+ )
162+ _, vi = DynamicPPL. evaluate!! (model, vi)
163+
164+ # Extract all the information we need
165+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
166+ logprior = DynamicPPL. getlogprior (vi)
167+ loglikelihood = DynamicPPL. getloglikelihood (vi)
168+
169+ # Convert values to the format needed (i.e. a Vector of (varname,
170+ # value) tuples, where value isa Real: all vector-valued varnames must
171+ # be split up.)
172+ # TODO (penelopeysm): This wouldn't be necessary if not for MCMCChains's
173+ # poor representation...
174+ values_split = if isempty (vals_as_in_model)
175+ # If there are no values, we return an empty vector.
176+ # This is the case for models with no parameters.
177+ Vector {Tuple{VarName,Any}} ()
178+ else
179+ iters = map (
180+ DynamicPPL. varname_and_value_leaves,
181+ keys (vals_as_in_model),
182+ values (vals_as_in_model),
183+ )
184+ mapreduce (collect, vcat, iters)
185+ end
137186
138- Transition (θ, lp) = Transition (θ, lp, nothing )
139- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
140- # TODO (DPPL0.37/penelopeysm): Fix this
141- θ = getparams (model, vi)
142- lp = getlogjoint_internal (vi)
143- return Transition (θ, lp, getstats (t))
144- end
187+ # Get additional statistics
188+ stats = getstats (sampler_transition)
189+ return new {typeof(values_split),typeof(logprior),typeof(stats)} (
190+ values_split, logprior, loglikelihood, stats
191+ )
192+ end
145193
146- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147- function metadata (t:: Transition )
148- stat = t. stat
149- if stat === nothing
150- return (lp= t. lp,)
151- else
152- return merge ((lp= t. lp,), stat)
194+ function Transition (
195+ model:: DynamicPPL.Model ,
196+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
197+ sampler_transition,
198+ )
199+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
200+ # much faster to convert it to a typed varinfo first, hence this method.
201+ # https://github.com/TuringLang/Turing.jl/issues/2604
202+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
153203 end
154204end
155205
156- # TODO (DPPL0.37/penelopeysm): Fix this
157- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
158-
159- # Metadata of VarInfo object
160- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
206+ function metadata (t:: Transition )
207+ return merge (
208+ t. stat,
209+ (
210+ lp= t. logprior + t. loglikelihood,
211+ logprior= t. logprior,
212+ loglikelihood= t. loglikelihood,
213+ ),
214+ )
215+ end
216+ function metadata (vi:: AbstractVarInfo )
217+ return (
218+ lp= DynamicPPL. getlogjoint (vi),
219+ logprior= DynamicPPL. getlogp (vi),
220+ loglikelihood= DynamicPPL. getloglikelihood (vi),
221+ )
222+ end
162223
163224# #########################
164225# Chain making utilities #
165226# #########################
166227
167- """
168- getparams(model, t)
169-
170- Return a named tuple of parameters.
171- """
172- getparams (model, t) = t. θ
173- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176- # of the parameters change depending on the realizations. Hence we have to use
177- # `values_as_in_model`, which re-runs the model and extracts the parameters
178- # as they are seen in the model, i.e. in the constrained space. Moreover,
179- # this means that the code below will work both of linked and invlinked `vi`.
180- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182- vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
183-
184- # Obtain an iterator over the flattened parameter names and values.
185- iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
186-
187- # Materialize the iterators and concatenate.
188- return mapreduce (collect, vcat, iters)
228+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
229+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
230+ t = Transition (model, vi, nothing )
231+ return getparams (model, t)
189232end
190- function getparams (
191- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
192- )
193- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
194- # much faster to convert it to a typed varinfo before calling getparams.
195- # https://github.com/TuringLang/Turing.jl/issues/2604
196- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
197- end
198- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
199- return float (Real)[]
200- end
201-
202233function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
203234 names_set = OrderedSet {VarName} ()
204235 # Extract the parameter names and values from each transition.
205236 dicts = map (ts) do t
237+ # TODO (penelopeysm): Get rid of AbstractVarInfo transitions. see
238+ # https://github.com/TuringLang/Turing.jl/issues/2631. That would
239+ # allow us to just use t.θ here.
206240 nms_and_vs = getparams (model, t)
207241 nms = map (first, nms_and_vs)
208242 vs = map (last, nms_and_vs)
@@ -221,7 +255,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
221255end
222256
223257function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
224- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
258+ valmat = reshape ([DynamicPPL . getlogjoint (t) for t in ts], :, 1 )
225259 return [:lp ], valmat
226260end
227261
@@ -463,16 +497,17 @@ function transitions_from_chain(
463497 chain:: MCMCChains.Chains ;
464498 sampler= DynamicPPL. SampleFromPrior (),
465499)
466- vi = Turing . VarInfo (model)
500+ vi = VarInfo (model)
467501
468502 iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
469503 transitions = map (iters) do (sample_idx, chain_idx)
470504 # Set variables present in `chain` and mark those NOT present in chain to be resampled.
505+ # TODO (DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
471506 DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
472507 model (rng, vi, sampler)
473508
474509 # Convert `VarInfo` into `NamedTuple` and save.
475- Transition (model, vi)
510+ Transition (model, vi, nothing )
476511 end
477512
478513 return transitions
0 commit comments