Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
m => 2.0

julia> values_as(vi, Vector)
2-element Vector{Real}:
2-element Vector{Float64}:
1.0
2.0
```
Expand Down
73 changes: 49 additions & 24 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both
have the same symbol `x`.

Several type aliases are provided for these forms of VarInfos:
- `VarInfo{<:Metadata}` is `UntypedVarInfo`
- `VarInfo{<:Metadata}` is `UntypedLegacyVarInfo`
- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo`
- `VarInfo{<:NamedTuple}` is `NTVarInfo`

Expand All @@ -107,7 +107,7 @@ struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
metadata::Tmeta
accs::Accs
end
function VarInfo(meta=Metadata())
function VarInfo(meta=VarNamedVector())
return VarInfo(meta, default_accumulators())
end

Expand Down Expand Up @@ -143,7 +143,7 @@ function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior
end

const UntypedVectorVarInfo = VarInfo{<:VarNamedVector}
const UntypedVarInfo = VarInfo{<:Metadata}
const UntypedLegacyVarInfo = VarInfo{<:Metadata}
# TODO: NTVarInfo carries no information about the type of the actual metadata
# i.e. the elements of the NamedTuple. It could be Metadata or it could be
# VarNamedVector.
Expand All @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple}
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
}
const UntypedVarInfo = UntypedVectorVarInfo

function Base.:(==)(vi1::VarInfo, vi2::VarInfo)
return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs)
Expand Down Expand Up @@ -194,8 +195,20 @@ end
# VarInfo constructors #
########################

function untyped_varinfo(
rng::Random.AbstractRNG,
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return untyped_vector_varinfo(rng, model, init_strategy)
end

function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return untyped_vector_varinfo(Random.default_rng(), model, init_strategy)
end

"""
untyped_varinfo([rng, ]model[, init_strategy])
untyped_legacy_varinfo([rng, ]model[, init_strategy])

Construct a VarInfo object for the given `model`, which has just a single
`Metadata` as its metadata field.
Expand All @@ -205,27 +218,29 @@ Construct a VarInfo object for the given `model`, which has just a single
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function untyped_varinfo(
function untyped_legacy_varinfo(
rng::Random.AbstractRNG,
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return last(init!!(rng, model, VarInfo(Metadata()), init_strategy))
end
function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return untyped_varinfo(Random.default_rng(), model, init_strategy)
function untyped_legacy_varinfo(
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
)
return untyped_legacy_varinfo(Random.default_rng(), model, init_strategy)
end

"""
typed_varinfo(vi::UntypedVarInfo)
typed_legacy_varinfo(vi::UntypedLegacyVarInfo)

This function finds all the unique `sym`s from the instances of `VarName{sym}` found in
`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the
global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as
a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each
symbol.
"""
function typed_varinfo(vi::UntypedVarInfo)
function typed_legacy_varinfo(vi::UntypedLegacyVarInfo)
meta = vi.metadata
new_metas = Metadata[]
# Symbols of all instances of `VarName{sym}` in `vi.vns`
Expand Down Expand Up @@ -289,12 +304,16 @@ function typed_varinfo(
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return typed_varinfo(untyped_varinfo(rng, model, init_strategy))
return typed_vector_varinfo(rng, model, init_strategy)
end
function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
return typed_varinfo(Random.default_rng(), model, init_strategy)
end

function typed_varinfo(vi::UntypedVectorVarInfo)
return typed_vector_varinfo(vi)
end

"""
untyped_vector_varinfo([rng, ]model[, init_strategy])

Expand All @@ -306,7 +325,7 @@ Return a VarInfo object for the given `model`, which has just a single
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function untyped_vector_varinfo(vi::UntypedVarInfo)
function untyped_vector_varinfo(vi::UntypedLegacyVarInfo)
md = metadata_to_varnamedvector(vi.metadata)
return VarInfo(md, copy(vi.accs))
end
Expand Down Expand Up @@ -626,11 +645,11 @@ end
const VarView = Union{Int,UnitRange,Vector{Int}}

"""
setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}})
setval!(vi::UntypedLegacyVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}})

Set the value of `vi.vals[vview]` to `val`.
"""
setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val
setval!(vi::UntypedLegacyVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val

"""
getmetadata(vi::VarInfo, vn::VarName)
Expand Down Expand Up @@ -825,10 +844,10 @@ set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi,

Returns a tuple of the unique symbols of random variables in `vi`.
"""
syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
syms(vi::UntypedLegacyVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
syms(vi::NTVarInfo) = keys(vi.metadata)

_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs)
_getidcs(vi::UntypedLegacyVarInfo) = 1:length(vi.metadata.idcs)
_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata)

@generated function _getidcs(metadata::NamedTuple{names}) where {names}
Expand Down Expand Up @@ -949,7 +968,7 @@ function link!!(
return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
end

function _link!!(vi::UntypedVarInfo, vns)
function _link!!(vi::UntypedLegacyVarInfo, vns)
# TODO: Change to a lazy iterator over `vns`
if ~is_transformed(vi, vns[1])
for vn in vns
Expand Down Expand Up @@ -1063,7 +1082,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
return maybe_invlink_before_eval!!(t, vi, model)
end

function _invlink!!(vi::UntypedVarInfo, vns)
function _invlink!!(vi::UntypedLegacyVarInfo, vns)
if is_transformed(vi, vns[1])
for vn in vns
f = linked_internal_to_internal_transform(vi, vn)
Expand Down Expand Up @@ -1477,7 +1496,7 @@ function _invlink_metadata!!(
end

# TODO(mhauru) The treatment of the case when some variables are transformed and others are
# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed`
# not should be revised. It used to be the case that for UntypedLegacyVarInfo `is_transformed`
# returned whether the first variable was linked. For NTVarInfo we did an OR over the first
# variables under each symbol. We now more consistently use OR, but I'm not convinced this
# is really the right thing to do.
Expand Down Expand Up @@ -1567,9 +1586,15 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`.
The value(s) may or may not be transformed to Euclidean space.
"""
setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi)

function BangBang.setindex!!(vi::VarInfo, val, vn::VarName)
setindex!(vi, val, vn)
return vi
md = setindex!!(getmetadata(vi, vn), val, vn)
return VarInfo(md, vi.accs)
end

function BangBang.setindex!!(vi::NTVarInfo, val, vn::VarName)
submd = setindex!!(getmetadata(vi, vn), val, vn)
return Accessors.@set vi.metadata[getsym(vn)] = submd
end

@inline function findvns(vi, f_vns)
Expand All @@ -1594,7 +1619,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
return any(md_haskey)
end

function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
function Base.show(io::IO, ::MIME"text/plain", vi::UntypedLegacyVarInfo)
lines = Tuple{String,Any}[
("VarNames", vi.metadata.vns),
("Range", vi.metadata.ranges),
Expand Down Expand Up @@ -1649,7 +1674,7 @@ function _show_varnames(io::IO, vi)
end
end

function Base.show(io::IO, vi::UntypedVarInfo)
function Base.show(io::IO, vi::UntypedLegacyVarInfo)
print(io, "VarInfo (")
_show_varnames(io, vi)
print(io, "; accumulators: ")
Expand Down Expand Up @@ -1821,11 +1846,11 @@ end

values_as(vi::VarInfo) = vi.metadata
values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon()))
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})
function values_as(vi::UntypedLegacyVarInfo, ::Type{NamedTuple})
iter = values_from_metadata(vi.metadata)
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
end
function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict}
function values_as(vi::UntypedLegacyVarInfo, ::Type{D}) where {D<:AbstractDict}
return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata))
end

Expand Down
9 changes: 6 additions & 3 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()

@testset "InitContext" begin
empty_varinfos = [
("untyped+metadata", VarInfo()),
("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())),
("untyped+metadata", VarInfo(DynamicPPL.Metadata())),
(
"typed+metadata",
DynamicPPL.typed_legacy_varinfo(VarInfo(DynamicPPL.Metadata())),
),
("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())),
(
"typed+VNV",
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
DynamicPPL.typed_vector_varinfo(VarInfo(DynamicPPL.VarNamedVector())),
),
("SVI+NamedTuple", SimpleVarInfo()),
("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())),
Expand Down
15 changes: 12 additions & 3 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@
return nothing
end
buggy_model = buggy_subsumes_demo_model()
varinfo = VarInfo(buggy_model)
@test_throws "should not subsume each other" DynamicPPL.untyped_varinfo(
buggy_model
)

varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model)
@test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo)
issuccess = check_model(buggy_model, varinfo)
@test !issuccess
Expand All @@ -94,8 +97,11 @@
return nothing
end
buggy_model = buggy_subsumes_demo_model()
varinfo = VarInfo(buggy_model)
@test_throws "should not subsume each other" DynamicPPL.untyped_varinfo(
buggy_model
)

varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model)
@test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo)
issuccess = check_model(buggy_model, varinfo)
@test !issuccess
Expand All @@ -112,8 +118,11 @@
return nothing
end
buggy_model = buggy_subsumes_demo_model()
varinfo = VarInfo(buggy_model)
@test_throws "should not subsume each other" DynamicPPL.untyped_varinfo(
buggy_model
)

varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model)
@test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo)
issuccess = check_model(buggy_model, varinfo)
@test !issuccess
Expand Down
2 changes: 1 addition & 1 deletion test/ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
end
end
@test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa
DynamicPPL.UntypedVarInfo
DynamicPPL.NTVarInfo

# In this model, the type error occurs in the user code rather than in DynamicPPL.
@model function demo5()
Expand Down
2 changes: 1 addition & 1 deletion test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function short_varinfo_name(vi::DynamicPPL.NTVarInfo)
"TypedVarInfo"
end
end
short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::DynamicPPL.UntypedLegacyVarInfo) = "UntypedLegacyVarInfo"
short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo"
function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref})
return "SimpleVarInfo{<:NamedTuple,<:Ref}"
Expand Down
14 changes: 9 additions & 5 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ end
end
model = gdemo(1.0, 2.0)

_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform())
tvi = DynamicPPL.typed_varinfo(vi)
# TODO(mhauru) Make this test more generic. It currently explicitly relies on
# Metadata.
_, vi = DynamicPPL.init!!(model, VarInfo(DynamicPPL.Metadata()), InitFromUniform())
tvi = DynamicPPL.typed_legacy_varinfo(vi)

meta = vi.metadata
for f in fieldnames(typeof(tvi.metadata))
Expand Down Expand Up @@ -290,7 +292,7 @@ end
dist = Normal(0, 1)
r = rand(dist)

push!!(vi, vn_x, r, dist)
vi = push!!(vi, vn_x, r, dist)

# is_transformed is set by default
@test !is_transformed(vi, vn_x)
Expand Down Expand Up @@ -353,7 +355,9 @@ end
# worth specifically checking that it can do this without having to
# change the VarInfo object.
# TODO(penelopeysm): Move this to InitFromUniform tests rather than here.
vi = VarInfo()
# TODO(mhauru) Make this test more generic. It currently explicitly relies on
# Metadata.
vi = VarInfo(DynamicPPL.Metadata())
meta = vi.metadata
_, vi = DynamicPPL.init!!(model, vi, InitFromUniform())
@test all(x -> !is_transformed(vi, x), meta.vns)
Expand All @@ -367,7 +371,7 @@ end
@test meta.vals ≈ v atol = 1e-10

# Check that linking and invlinking preserves the values
vi = DynamicPPL.typed_varinfo(vi)
vi = DynamicPPL.typed_legacy_varinfo(vi)
meta = vi.metadata
v_s = copy(meta.s.vals)
v_m = copy(meta.m.vals)
Expand Down
Loading