@@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple}
154154const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
155155 VarInfo{Tmeta},ThreadSafeVarInfo{<: VarInfo{Tmeta} }
156156}
157+ const TupleVarInfo = VarInfo{<: VarNamedTuple }
157158
158159function Base.:(== )(vi1:: VarInfo , vi2:: VarInfo )
159160 return (vi1. metadata == vi2. metadata && vi1. accs == vi2. accs)
@@ -356,6 +357,28 @@ function typed_vector_varinfo(
356357 return typed_vector_varinfo (Random. default_rng (), model, init_strategy)
357358end
358359
360+ function make_leaf_metadata ((r, dist), optic)
361+ md = Metadata ()
362+ vn = VarName {:_} (optic)
363+ push! (md, vn, r, dist)
364+ return md
365+ end
366+
367+ function tuple_varinfo ()
368+ metadata = VarNamedTuple ((;), make_leaf_metadata)
369+ return VarInfo (metadata, copy (default_accumulators ()))
370+ end
371+ function tuple_varinfo (
372+ rng:: Random.AbstractRNG ,
373+ model:: Model ,
374+ init_strategy:: AbstractInitStrategy = InitFromPrior (),
375+ )
376+ return last (init!! (rng, model, tuple_varinfo (), init_strategy))
377+ end
378+ function tuple_varinfo (model:: Model , init_strategy:: AbstractInitStrategy = InitFromPrior ())
379+ return tuple_varinfo (Random. default_rng (), model, init_strategy)
380+ end
381+
359382"""
360383 vector_length(varinfo::VarInfo)
361384
@@ -639,6 +662,9 @@ Return the metadata in `vi` that belongs to `vn`.
639662"""
640663getmetadata (vi:: VarInfo , vn:: VarName ) = vi. metadata
641664getmetadata (vi:: NTVarInfo , vn:: VarName ) = getfield (vi. metadata, getsym (vn))
665+ function getmetadata (vi:: TupleVarInfo , vn:: VarName )
666+ return getindex (vi. metadata, remove_trailing_index (vn))
667+ end
642668
643669"""
644670 getidx(vi::VarInfo, vn::VarName)
744770Return the distribution from which `vn` was sampled in `vi`.
745771"""
746772getdist (vi:: VarInfo , vn:: VarName ) = getdist (getmetadata (vi, vn), vn)
773+ function getdist (vi:: TupleVarInfo , vn:: VarName )
774+ main_vn, optic = split_trailing_index (vn)
775+ return getdist (getindex (vi. metadata, main_vn), VarName {:_} (optic))
776+ end
747777getdist (md:: Metadata , vn:: VarName ) = md. dists[getidx (md, vn)]
748778# TODO (mhauru) Remove this once the old Gibbs sampler stuff is gone.
749779function getdist (:: VarNamedVector , :: VarName )
@@ -782,6 +812,10 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
782812The values may or may not be transformed to Euclidean space.
783813"""
784814setval! (vi:: VarInfo , val, vn:: VarName ) = setval! (getmetadata (vi, vn), val, vn)
815+ function setval! (vi:: TupleVarInfo , val, vn:: VarName )
816+ main_vn, optic = split_trailing_index (vn)
817+ return setval! (getindex (vi. metadata, main_vn), VarName {:_} (optic))
818+ end
785819function setval! (md:: Metadata , val:: AbstractVector , vn:: VarName )
786820 return md. vals[getrange (md, vn)] = val
787821end
@@ -1579,6 +1613,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
15791613 end
15801614 return any (md_haskey)
15811615end
1616+ Base. haskey (vi:: TupleVarInfo , vn:: VarName ) = haskey (vi. metadata, vn)
15821617
15831618function Base. show (io:: IO , :: MIME"text/plain" , vi:: UntypedVarInfo )
15841619 lines = Tuple{String,Any}[
@@ -1673,6 +1708,25 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
16731708 return vi
16741709end
16751710
1711+ function BangBang. push!! (vi:: TupleVarInfo , vn:: VarName , r, dist:: Distribution )
1712+ @assert ~ (haskey (vi, vn)) " [push!!] attempt to add an existing variable $(getsym (vn)) ($(vn) ) to TupleVarInfo with dist=$dist "
1713+ return VarInfo (setindex!! (vi. metadata, (r, dist), vn), vi. accs)
1714+ end
1715+
1716+ # TODO (mhauru) Implement properly
1717+ function is_transformed (vi:: TupleVarInfo , vn:: VarName )
1718+ return false
1719+ end
1720+
1721+ function getindex (vi:: TupleVarInfo , vn:: VarName )
1722+ main_vn, optic = split_trailing_index (vn)
1723+ return getindex (getindex (vi. metadata, main_vn), VarName {:_} (optic))
1724+ end
1725+ function getindex_internal (vi:: TupleVarInfo , vn:: VarName )
1726+ main_vn, optic = split_trailing_index (vn)
1727+ return getindex_internal (getindex (vi. metadata, main_vn), VarName {:_} (optic))
1728+ end
1729+
16761730function Base. push! (vi:: UntypedVectorVarInfo , vn:: VarName , val, args... )
16771731 push! (getmetadata (vi, vn), vn, val, args... )
16781732 return vi
0 commit comments