diff --git a/HISTORY.md b/HISTORY.md index 54b40b7e9..4b8f5980e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.38.4 + +Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on. + ## 0.38.3 Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`. diff --git a/Project.toml b/Project.toml index d54f9d1da..0773bbe04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.3" +version = "0.38.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index 31b7d07da..b04bd445d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -414,7 +414,7 @@ DynamicPPL.reset! DynamicPPL.update! DynamicPPL.insert! DynamicPPL.loosen_types!! -DynamicPPL.tighten_types +DynamicPPL.tighten_types!! ``` ```@docs diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 396e1463f..44dbc5508 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -180,7 +180,9 @@ function tilde_assume!!( end # Neither of these set the `trans` flag so we have to do it manually if # necessary. - insert_transformed_value && set_transformed!!(vi, true, vn) + if insert_transformed_value + vi = set_transformed!!(vi, true, vn) + end # `accumulate_assume!!` wants untransformed values as the second argument. vi = accumulate_assume!!(vi, x, logjac, vn, dist) # We always return the untransformed value here, as that will determine diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 13124e3a7..e8b50a0b7 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true) show_varname(io::IO, varname::VarName) = print(io, varname) function show_varname(io::IO, varname::Array{<:VarName,N}) where {N} # Attempt to make the type concrete in case the symbol is shared. 
- return _show_varname(io, map(identity, varname)) + return _show_varname(io, [vn for vn in varname]) end function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N} # Print the first and last element of the array. @@ -407,7 +407,7 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually +julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually # alert us to the issue of `x` being sampled twice. model = demo_incorrect(); varinfo = VarInfo(model); diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e5e6a6dae..3b7b84953 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ box: - [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of linking - [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + by linking, since transforms are only applied to random variables) !!! note By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the @@ -146,7 +146,7 @@ struct LogDensityFunction{ is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep - x = map(identity, varinfo[:]) + x = [val for val in varinfo[:]] if use_closure(adtype) prep = DI.prepare_gradient( LogDensityAt(model, getlogdensity, varinfo), adtype, x @@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient( ) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") - x = map(identity, x) # Concretise type + x = [val for val in x] # Concretise type # Make branching statically inferrable, i.e. 
type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2ba25f142..434480be6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", ) end + return vi end is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 79442fccf..a49ffd18b 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups: 1. _How to specify the results to compare against._ Once logp and its gradient has been calculated with the specified `adtype`, - it can optionally be tested for correctness. The exact way this is tested + it can optionally be tested for correctness. The exact way this is tested is specified in the `test` parameter. 
There are several options for this: @@ -260,7 +260,7 @@ function run_ad( if isnothing(params) params = varinfo[:] end - params = map(identity, params) # Concretise + params = [p for p in params] # Concretise # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" diff --git a/src/varinfo.jl b/src/varinfo.jl index 734bf3db5..a90b81488 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -315,7 +315,7 @@ function untyped_vector_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) + return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy)) end function untyped_vector_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() @@ -789,10 +789,16 @@ function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end +function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName) + md = set_transformed!!(getmetadata(vi, vn), val, vn) + return Accessors.@set vi.metadata[getsym(vn)] = md +end + function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) - set_transformed!!(getmetadata(vi, vn), val, vn) - return vi + md = set_transformed!!(getmetadata(vi, vn), val, vn) + return VarInfo(md, vi.accs) end + function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) metadata.is_transformed[getidx(metadata, vn)] = val return metadata @@ -800,7 +806,7 @@ end function set_transformed!!(vi::VarInfo, val::Bool) for vn in keys(vi) - set_transformed!!(vi, val, vn) + vi = set_transformed!!(vi, val, vn) end return vi @@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns) end @generated function _link!!( - ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} + ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) for f in metadata_names @@ -988,7 
+994,7 @@ end expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns.$f, f_vns) + f_vns = filter_subsumed(varnames.$f, f_vns) if !isempty(f_vns) if !is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform @@ -1652,30 +1658,47 @@ end Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to the `VarInfo` `vi`, mutating if it makes sense. """ -function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa NTVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" - end +function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" + md = push!!(getmetadata(vi, vn), vn, val, dist) + return VarInfo(md, vi.accs) +end +function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" sym = getsym(vn) - if vi isa NTVarInfo && ~haskey(vi.metadata, sym) + meta = if ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. - val = tovec(r) - md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) - vi = Accessors.@set vi.metadata[sym] = md + _new_submetadata(vi, vn, val, dist) else - meta = getmetadata(vi, vn) - push!(meta, vn, r, dist) + push!!(getmetadata(vi, vn), vn, val, dist) end - + vi = Accessors.@set vi.metadata[sym] = meta return vi end -function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) 
- push!(getmetadata(vi, vn), vn, val, args...) - return vi +""" + _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} + +Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing +SubMetas. +""" +@generated function _new_submetadata( + vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist +) where {Names,SubMetas} + has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) + return if has_vnv + :(return _new_vnv_submetadata(vn, r, dist)) + else + :(return _new_metadata_submetadata(vn, r, dist)) + end +end + +_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) + +function _new_metadata_submetadata(vn, r, dist) + val = tovec(r) + return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) end function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) @@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist) return meta end +function BangBang.push!!(meta::Metadata, vn, r, dist) + push!(meta, vn, r, dist) + return meta +end + function Base.delete!(vi::VarInfo, vn::VarName) delete!(getmetadata(vi, vn), vn) return vi diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index f68498e46..2c66e1245 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -56,13 +56,13 @@ $(FIELDS) The values for different variables are internally all stored in a single vector. 
For instance, ```jldoctest varnamedvector-struct -julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal +julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal julia> vnv = VarNamedVector(); -julia> setindex!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); +julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); -julia> setindex!(vnv, reshape(1:6, (2,3)), @varname(y)); +julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y)); julia> vnv.vals 10-element Vector{Real}: @@ -91,7 +91,7 @@ If a variable is updated with a new value that is of a smaller dimension than th value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive. ```jldoctest varnamedvector-struct -julia> update!(vnv, [46.0, 48.0], @varname(x)) +julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x)); julia> vnv.vals 10-element Vector{Real}: @@ -107,7 +107,7 @@ julia> vnv.vals 6 julia> println(vnv.num_inactive); -OrderedDict(1 => 2) +Dict(1 => 2) ``` This helps avoid unnecessary memory allocations for values that repeatedly change dimension. 
@@ -133,17 +133,17 @@ julia> getindex_internal(vnv, :) ``` """ struct VarNamedVector{ - K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector + K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T} } """ mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` """ - varname_to_index::OrderedDict{K,Int} + varname_to_index::Dict{K,Int} """ vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` """ - varnames::TVN # AbstractVector{<:VarName} + varnames::KVec """ vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has @@ -156,14 +156,14 @@ struct VarNamedVector{ vector of values of all variables; the value(s) of `vn` is/are `vals[ranges[varname_to_index[vn]]]` """ - vals::TVal # AbstractVector{<:Real} + vals::VVec """ vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable that transforms the value of `vn` back to its original space, undoing any linking and vectorisation """ - transforms::TTrans + transforms::TVec """ vector of booleans indicating whether a variable has been explicitly transformed to @@ -182,18 +182,18 @@ struct VarNamedVector{ Inactive entries always come after the last active entry for the given variable. See the extended help with `??VarNamedVector` for more details. 
""" - num_inactive::OrderedDict{Int,Int} + num_inactive::Dict{Int,Int} function VarNamedVector( varname_to_index, - varnames::TVN, + varnames::KVec, ranges, - vals::TVal, - transforms::TTrans, + vals::VVec, + transforms::TVec, is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), - num_inactive=OrderedDict{Int,Int}(); + num_inactive=Dict{Int,Int}(); check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, - ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} + ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} if check_consistency if length(varnames) != length(ranges) || length(varnames) != length(transforms) || @@ -257,7 +257,7 @@ struct VarNamedVector{ # tiny bit of thought. end - return new{K,V,TVN,TVal,TTrans}( + return new{K,V,T,KVec,VVec,TVec}( varname_to_index, varnames, ranges, @@ -269,18 +269,13 @@ struct VarNamedVector{ end end -function VarNamedVector{K,V}() where {K,V} +function VarNamedVector{K,V,T}() where {K,V,T} return VarNamedVector( - OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]; check_consistency=false + Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false ) end -# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Simlarly the -# transform vector type above could then be Union{}[]. This would allow expanding the -# VarName and element types only as necessary, which would help keep them concrete. However, -# making that change here opens some other cans of worms related to how VarInfo uses -# BangBang, that I don't want to deal with right now. 
-VarNamedVector() = VarNamedVector{VarName,Real}() +VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}() function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) end @@ -298,16 +293,17 @@ function VarNamedVector( transforms=fill(identity, length(varnames)); check_consistency=CHECK_CONSISTENCY_DEFAULT, ) + if isempty(varnames) && isempty(orig_vals) && isempty(transforms) + return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}() + end # Convert `vals` into a vector of vectors. vals_vecs = map(tovec, orig_vals) transforms = map( (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals ) - # TODO: Is this really the way to do this? - if !(eltype(varnames) <: VarName) - varnames = convert(Vector{VarName}, varnames) - end - varname_to_index = OrderedDict{eltype(varnames),Int}( + # Make `varnames` have as concrete an element type as possible. + varnames = [v for v in varnames] + varname_to_index = Dict{eltype(varnames),Int}( vn => i for (i, vn) in enumerate(varnames) ) vals = reduce(vcat, vals_vecs) @@ -345,6 +341,12 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) vnv_left.num_inactive == vnv_right.num_inactive end +function is_concretely_typed(vnv::VarNamedVector) + return isconcretetype(eltype(vnv.varnames)) && + isconcretetype(eltype(vnv.vals)) && + isconcretetype(eltype(vnv.transforms)) +end + getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] @@ -562,7 +564,7 @@ to be the default vectorisation transform. This undoes any possible linking. ```jldoctest varnamedvector-reset julia> using DynamicPPL: VarNamedVector, @varname, reset! 
-julia> vnv = VarNamedVector(); +julia> vnv = VarNamedVector{VarName,Any,Any}(); julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); @@ -797,11 +799,16 @@ function update_internal!( return nothing end -function BangBang.push!(vnv::VarNamedVector, vn, val, dist) +function Base.push!(vnv::VarNamedVector, vn, val, dist) f = from_vec_transform(dist) return setindex_internal!(vnv, tovec(val), vn, f) end +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new @@ -810,7 +817,7 @@ end # with every ! call replaced with a !! call. """ - loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew}) + loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. @@ -821,7 +828,7 @@ transformations of type `TransNew` can be pushed to it. Some of the underlying s shared between `vnv` and the return value, and thus mutating one may affect the other. # See also -[`tighten_types`](@ref) +[`tighten_types!!`](@ref) # Examples @@ -836,7 +843,9 @@ julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) ERROR: MethodError: Cannot `convert` an object of type [...] 
-julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), typeof(y_trans)); +julia> vnv_loose = DynamicPPL.loosen_types!!( + vnv, typeof(@varname(y)), Float64, typeof(y_trans) + ); julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) @@ -847,40 +856,63 @@ julia> vnv_loose[@varname(y)] ``` """ function loosen_types!!( - vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew} -) where {KNew,TransNew} + vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} +) where {KNew,VNew,TNew} K = eltype(vnv.varnames) - Trans = eltype(vnv.transforms) - if KNew <: K && TransNew <: Trans + V = eltype(vnv.vals) + T = eltype(vnv.transforms) + if KNew <: K && VNew <: V && TNew <: T return vnv else - vn_type = promote_type(K, KNew) - transform_type = promote_type(Trans, TransNew) - return VarNamedVector( - OrderedDict{vn_type,Int}(vnv.varname_to_index), - Vector{vn_type}(vnv.varnames), - vnv.ranges, - vnv.vals, - Vector{transform_type}(vnv.transforms), - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing + # and Missing, rather than falling back on Any. However, it's not exported. + vn_type = typejoin(K, KNew) + val_type = typejoin(V, VNew) + transform_type = typejoin(T, TNew) + # This function would work the same way if the first if statement a few lines above + # was skipped, and we only checked for the below condition. However, the first one + # is constant propagated away at compile time (at least on Julia v1.11.7), whereas + # this one isn't. Hence we keep both for performance. 
+ return if vn_type == K && val_type == V && transform_type == T + vnv + elseif isempty(vnv) + VarNamedVector(vn_type[], val_type[], transform_type[]) + else + # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but + # then here always revert to Vector. + VarNamedVector( + Dict{vn_type,Int}(vnv.varname_to_index), + Vector{vn_type}(vnv.varnames), + vnv.ranges, + Vector{val_type}(vnv.vals), + Vector{transform_type}(vnv.transforms), + vnv.is_unconstrained, + vnv.num_inactive; + check_consistency=false, + ) + end end end """ - tighten_types(vnv::VarNamedVector) + tighten_types!!(vnv::VarNamedVector) + +Return a `VarNamedVector` like `vnv` with the most concrete types possible. -Return a copy of `vnv` with the most concrete types possible. +This function either returns `vnv` itself or a new `VarNamedVector` with the same values in +it, but with the element types of various containers made as concrete as possible. For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the transforms are actually identity transformations, this function will return a new `VarNamedVector` with the transforms vector having eltype `typeof(identity)`. -This is a lot like the reverse of [`loosen_types!!`](@ref), but with two notable -differences: Unlike `loosen_types!!`, this function does not mutate `vnv`; it also changes -not only the key and transform eltypes, but also the values eltype. +This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the +return value may share some of its underlying storage with `vnv`, and thus mutating one may +affect the other. # See also [`loosen_types!!`](@ref) @@ -890,9 +922,9 @@ not only the key and transform eltypes, but also the values eltype. ```jldoctest varnamedvector-tighten-types julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! 
-julia> vnv = VarNamedVector(); +julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); -julia> setindex!(vnv, [23], @varname(x)) +julia> vnv = delete!(vnv, @varname(y)); julia> eltype(vnv) Real @@ -901,7 +933,7 @@ julia> vnv.transforms 1-element Vector{Any}: identity (generic function with 1 method) -julia> vnv_tight = DynamicPPL.tighten_types(vnv); +julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); julia> eltype(vnv_tight) == Int true @@ -911,17 +943,24 @@ julia> vnv_tight.transforms identity (generic function with 1 method) ``` """ -function tighten_types(vnv::VarNamedVector) - return VarNamedVector( - OrderedDict(vnv.varname_to_index...), - map(identity, vnv.varnames), - copy(vnv.ranges), - map(identity, vnv.vals), - map(identity, vnv.transforms), - copy(vnv.is_unconstrained), - copy(vnv.num_inactive); - check_consistency=false, - ) +function tighten_types!!(vnv::VarNamedVector) + return if is_concretely_typed(vnv) + # There can not be anything to tighten, so short-circuit. 
+ vnv + elseif isempty(vnv) + VarNamedVector() + else + VarNamedVector( + Dict(vnv.varname_to_index...), + [x for x in vnv.varnames], + vnv.ranges, + [x for x in vnv.vals], + [x for x in vnv.transforms], + vnv.is_unconstrained, + vnv.num_inactive; + check_consistency=false, + ) + end end function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) @@ -973,18 +1012,22 @@ function setindex_internal!!( end end -function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) +function insert_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) if transform === nothing transform = identity end - vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) + vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) insert_internal!(vnv, val, vn, transform) return vnv end -function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) +function update_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform - vnv = loosen_types!!(vnv, typeof(vn), typeof(transform_resolved)) + vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) return vnv end @@ -1134,12 +1177,12 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Determine `eltype` of `vals`. T_left = eltype(left_vnv.vals) T_right = eltype(right_vnv.vals) - T = promote_type(T_left, T_right) + T = typejoin(T_left, T_right) # Determine `eltype` of `varnames`. V_left = eltype(left_vnv.varnames) V_right = eltype(right_vnv.varnames) - V = promote_type(V_left, V_right) + V = typejoin(V_left, V_right) if !(V <: VarName) V = VarName end @@ -1147,10 +1190,10 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Determine `eltype` of `transforms`. 
F_left = eltype(left_vnv.transforms) F_right = eltype(right_vnv.transforms) - F = promote_type(F_left, F_right) + F = typejoin(F_left, F_right) # Allocate. - varname_to_index = OrderedDict{V,Int}() + varname_to_index = Dict{V,Int}() ranges = UnitRange{Int}[] vals = T[] transforms = F[] @@ -1219,7 +1262,6 @@ julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) true """ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) - # NOTE: This does not specialize types when possible. vnv_new = similar(vnv) # Return early if possible. isempty(vnv) && return vnv_new @@ -1231,7 +1273,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) end end - return vnv_new + return tighten_types!!(vnv_new) end """ @@ -1430,7 +1472,7 @@ true """ function group_by_symbol(vnv::VarNamedVector) symbols = unique(map(getsym, vnv.varnames)) - nt_vals = map(s -> tighten_types(subset(vnv, [VarName{s}()])), symbols) + nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols) return OrderedDict(zip(symbols, nt_vals)) end @@ -1508,6 +1550,16 @@ function Base.delete!(vnv::VarNamedVector, vn::VarName) return vnv end +""" + delete!!(vnv::VarNamedVector, vn::VarName) + +Like `delete!`, but tightens the element types of the returned `VarNamedVector`. + +# See also +[`tighten_types!!`](@ref) +""" +BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn)) + """ values_as(vnv::VarNamedVector[, T]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 3fd76ffe2..b764d517b 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -148,10 +148,10 @@ end # Empty. vnv = DynamicPPL.VarNamedVector() @test isempty(vnv) - @test eltype(vnv) == Real + @test eltype(vnv) == Union{} # Empty with types. 
- vnv = DynamicPPL.VarNamedVector{VarName,Float64}() + vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() @test isempty(vnv) @test eltype(vnv) == Float64 end @@ -369,13 +369,17 @@ end # Explicitly setting the transformation. increment(x) = x .+ 10 vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), typeof(increment)) + vnv = DynamicPPL.loosen_types!!( + vnv, typeof(vn_left), eltype(vnv), typeof(increment) + ) DynamicPPL.setindex_internal!( vnv, to_vec_left(val_left .+ 100), vn_left, increment ) @test vnv[vn_left] == to_vec_left(val_left .+ 110) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), typeof(increment)) + vnv = DynamicPPL.loosen_types!!( + vnv, typeof(vn_right), eltype(vnv), typeof(increment) + ) DynamicPPL.setindex_internal!( vnv, to_vec_right(val_right .+ 100), vn_right, increment )