@@ -176,13 +176,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
176176 # this means that the code below will work both of linked and invlinked `vi`.
177177 # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
178178 # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
179- vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
180-
181- # Obtain an iterator over the flattened parameter names and values.
182- iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
183-
184- # Materialize the iterators and concatenate.
185- return mapreduce (collect, vcat, iters)
179+ return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
186180end
187181function getparams (
188182 model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
@@ -193,14 +187,25 @@ function getparams(
193187 return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
194188end
195189function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
196- return float (Real)[]
190+ return Dict {VarName,Any} ()
197191end
198192
199193function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
200194 names_set = OrderedSet {VarName} ()
201195 # Extract the parameter names and values from each transition.
202196 dicts = map (ts) do t
203- nms_and_vs = getparams (model, t)
197+ # In general getparams returns a dict of VarName => values. We need to also
198+ # split it up into constituent elements using
199+ # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
200+ # won't understand it.
201+ vals = getparams (model, t)
202+ nms_and_vs = if isempty (vals)
203+ Tuple{VarName,Any}[]
204+ else
205+ iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
206+ mapreduce (collect, vcat, iters)
207+ end
208+
204209 nms = map (first, nms_and_vs)
205210 vs = map (last, nms_and_vs)
206211 for nm in nms
@@ -210,9 +215,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
210215 return OrderedDict (zip (nms, vs))
211216 end
212217 names = collect (names_set)
213- vals = [
214- get (dicts[i], key, missing ) for i in eachindex (dicts), (j, key) in enumerate (names)
215- ]
218+ vals = [get (dicts[i], key, missing ) for i in eachindex (dicts), key in names]
216219
217220 return names, vals
218221end
0 commit comments