8080
8181"""
8282 hasvalue(
83- vals::AbstractDict,
83+ vals::Union{ AbstractDict,NamedTuple} ,
8484 vn::VarName,
8585 dist::Distribution;
8686 error_on_incomplete::Bool=false
@@ -98,6 +98,11 @@ the values needed for `vn` are present, but others are not. This may help
9898to detect invalid cases where the user has provided e.g. data of the wrong
9999shape.
100100
101+ Note that this check is only possible if a Dict is passed, because the key type
102+ of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
103+ information. If this method is called with a NamedTuple, it will just defer
104+ to `hasvalue(vals, vn)`.
105+
101106For example:
102107
103108```jldoctest; setup=:(using Distributions, LinearAlgebra))
@@ -114,6 +119,16 @@ ERROR: hasvalue: only partial values for `x` found in the values provided
114119[...]
115120```
116121"""
122+ function AbstractPPL. hasvalue (
123+ vals:: NamedTuple ,
124+ vn:: VarName ,
125+ dist:: Distributions.Distribution ;
126+ error_on_incomplete:: Bool = false ,
127+ )
128+ # NamedTuples can't have such complicated hierarchies, so it's safe to
129+ # defer to the simpler `hasvalue(vals, vn)`.
130+ return hasvalue (vals, vn)
131+ end
117132function AbstractPPL. hasvalue (
118133 vals:: AbstractDict ,
119134 vn:: VarName ,
@@ -169,7 +184,11 @@ function AbstractPPL.hasvalue(
169184end
170185
171186"""
172- getvalue(vals::AbstractDict, vn::VarName, dist::Distribution)
187+ getvalue(
188+ vals::Union{AbstractDict,NamedTuple},
189+ vn::VarName,
190+ dist::Distribution
191+ )
173192
174193Retrieve the value of `vn` from `vals`, using the distribution `dist` to
175194reconstruct the value if necessary.
@@ -178,6 +197,11 @@ This is a more general version of `getvalue(vals, vn)`, in that even if `vn`
178197itself is not inside `vals`, it can still reconstruct the value of `vn`
179198from sub-values of `vn` that are present in `vals`.
180199
200+ Note that this reconstruction is only possible if a Dict is passed, because the
201+ key type of a NamedTuple (i.e., Symbol) is not rich enough to carry indexing
202+ information. If this method is called with a NamedTuple, it will just defer
203+ to `getvalue(vals, vn)`.
204+
181205For example:
182206
183207```jldoctest; setup=:(using Distributions, LinearAlgebra))
@@ -194,6 +218,15 @@ ERROR: getvalue: `x` was not found in the values provided
194218[...]
195219```
196220"""
221+ function AbstractPPL. getvalue (
222+ vals:: NamedTuple ,
223+ vn:: VarName ,
224+ dist:: Distributions.Distribution
225+ )
226+ # NamedTuples can't have such complicated hierarchies, so it's safe to
227+ # defer to the simpler `getvalue(vals, vn)`.
228+ return getvalue (vals, vn)
229+ end
197230function AbstractPPL. getvalue (
198231 vals:: AbstractDict , vn:: VarName , dist:: Distributions.Distribution ;
199232)
0 commit comments