Skip to content

Commit eb2836d

Browse files
committed
Sketching VarNamedTuple and its VarInfo
1 parent 262d732 commit eb2836d

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ include("contexts/prefix.jl")
178178
include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl
179179
include("model.jl")
180180
include("varname.jl")
181+
include("varnamedtuple.jl")
182+
using .VarNamedTuples: VarNamedTuple
181183
include("distribution_wrappers.jl")
182184
include("submodel.jl")
183185
include("varnamedvector.jl")

src/varinfo.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple}
154154
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
155155
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
156156
}
157+
const TupleVarInfo = VarInfo{<:VarNamedTuple}
157158

158159
function Base.:(==)(vi1::VarInfo, vi2::VarInfo)
159160
return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs)
@@ -356,6 +357,28 @@ function typed_vector_varinfo(
356357
return typed_vector_varinfo(Random.default_rng(), model, init_strategy)
357358
end
358359

360+
function make_leaf_metadata((r, dist), optic)
361+
md = Metadata()
362+
vn = VarName{:_}(optic)
363+
push!(md, vn, r, dist)
364+
return md
365+
end
366+
367+
function tuple_varinfo()
368+
metadata = VarNamedTuple((;), make_leaf_metadata)
369+
return VarInfo(metadata, copy(default_accumulators()))
370+
end
371+
function tuple_varinfo(
372+
rng::Random.AbstractRNG,
373+
model::Model,
374+
init_strategy::AbstractInitStrategy=InitFromPrior(),
375+
)
376+
return last(init!!(rng, model, tuple_varinfo(), init_strategy))
377+
end
378+
function tuple_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
379+
return tuple_varinfo(Random.default_rng(), model, init_strategy)
380+
end
381+
359382
"""
360383
vector_length(varinfo::VarInfo)
361384
@@ -639,6 +662,9 @@ Return the metadata in `vi` that belongs to `vn`.
639662
"""
640663
getmetadata(vi::VarInfo, vn::VarName) = vi.metadata
641664
getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn))
665+
function getmetadata(vi::TupleVarInfo, vn::VarName)
666+
return getindex(vi.metadata, remove_trailing_index(vn))
667+
end
642668

643669
"""
644670
getidx(vi::VarInfo, vn::VarName)
@@ -744,6 +770,10 @@ end
744770
Return the distribution from which `vn` was sampled in `vi`.
745771
"""
746772
getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
773+
function getdist(vi::TupleVarInfo, vn::VarName)
774+
main_vn, optic = split_trailing_index(vn)
775+
return getdist(getindex(vi.metadata, main_vn), VarName{:_}(optic))
776+
end
747777
getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
748778
# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone.
749779
function getdist(::VarNamedVector, ::VarName)
@@ -782,6 +812,10 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
782812
The values may or may not be transformed to Euclidean space.
783813
"""
784814
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
815+
function setval!(vi::TupleVarInfo, val, vn::VarName)
816+
main_vn, optic = split_trailing_index(vn)
817+
return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic))
818+
end
785819
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
786820
return md.vals[getrange(md, vn)] = val
787821
end
@@ -1579,6 +1613,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
15791613
end
15801614
return any(md_haskey)
15811615
end
1616+
Base.haskey(vi::TupleVarInfo, vn::VarName) = haskey(vi.metadata, vn)
15821617

15831618
function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
15841619
lines = Tuple{String,Any}[
@@ -1673,6 +1708,25 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
16731708
return vi
16741709
end
16751710

1711+
function BangBang.push!!(vi::TupleVarInfo, vn::VarName, r, dist::Distribution)
1712+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TupleVarInfo with dist=$dist"
1713+
return VarInfo(setindex!!(vi.metadata, (r, dist), vn), vi.accs)
1714+
end
1715+
1716+
# TODO(mhauru) Implement properly
1717+
function is_transformed(vi::TupleVarInfo, vn::VarName)
1718+
return false
1719+
end
1720+
1721+
function getindex(vi::TupleVarInfo, vn::VarName)
1722+
main_vn, optic = split_trailing_index(vn)
1723+
return getindex(getindex(vi.metadata, main_vn), VarName{:_}(optic))
1724+
end
1725+
function getindex_internal(vi::TupleVarInfo, vn::VarName)
1726+
main_vn, optic = split_trailing_index(vn)
1727+
return getindex_internal(getindex(vi.metadata, main_vn), VarName{:_}(optic))
1728+
end
1729+
16761730
function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...)
16771731
push!(getmetadata(vi, vn), vn, val, args...)
16781732
return vi

src/varname.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,24 @@ Possibly existing indices of `varname` are neglected.
4141
) where {s,missings,_F,_a,_T}
4242
return s in missings
4343
end
44+
45+
function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
46+
return if Optic === typeof(identity)
47+
vn
48+
elseif Optic isa IndexLens
49+
VarName{sym}()
50+
else
51+
prefix(remove_trailing_index(unprefix(vn, sym)), sym)
52+
end
53+
end
54+
55+
function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
56+
return if Optic === typeof(identity)
57+
(vn, identity)
58+
elseif Optic isa IndexLens
59+
(VarName{sym}(), Optic.index)
60+
else
61+
(prefix, index) = split_trailing_index(unprefix(vn, sym))
62+
(prefix(prefix, sym), index)
63+
end
64+
end

0 commit comments

Comments
 (0)