From 150de719fe459a8718bc9ade591ac835813e48dc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:09:57 +0000 Subject: [PATCH 01/11] Change VNV to use Dict rather than OrderedDict --- src/varnamedvector.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index f68498e46..e373a4e95 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -107,7 +107,7 @@ julia> vnv.vals 6 julia> println(vnv.num_inactive); -OrderedDict(1 => 2) +Dict(1 => 2) ``` This helps avoid unnecessary memory allocations for values that repeatedly change dimension. @@ -138,7 +138,7 @@ struct VarNamedVector{ """ mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` """ - varname_to_index::OrderedDict{K,Int} + varname_to_index::Dict{K,Int} """ vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` @@ -182,7 +182,7 @@ struct VarNamedVector{ Inactive entries always come after the last active entry for the given variable. See the extended help with `??VarNamedVector` for more details. """ - num_inactive::OrderedDict{Int,Int} + num_inactive::Dict{Int,Int} function VarNamedVector( varname_to_index, @@ -191,7 +191,7 @@ struct VarNamedVector{ vals::TVal, transforms::TTrans, is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), - num_inactive=OrderedDict{Int,Int}(); + num_inactive=Dict{Int,Int}(); check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} if check_consistency @@ -307,7 +307,7 @@ function VarNamedVector( if !(eltype(varnames) <: VarName) varnames = convert(Vector{VarName}, varnames) end - varname_to_index = OrderedDict{eltype(varnames),Int}( + varname_to_index = Dict{eltype(varnames),Int}( vn => i for (i, vn) in enumerate(varnames) ) vals = reduce(vcat, vals_vecs) @@ -857,7 +857,7 @@ function loosen_types!!( vn_type = promote_type(K, KNew) transform_type = promote_type(Trans, TransNew) return VarNamedVector( - OrderedDict{vn_type,Int}(vnv.varname_to_index), + Dict{vn_type,Int}(vnv.varname_to_index), Vector{vn_type}(vnv.varnames), vnv.ranges, vnv.vals, @@ -913,7 +913,7 @@ julia> vnv_tight.transforms """ function tighten_types(vnv::VarNamedVector) return VarNamedVector( - OrderedDict(vnv.varname_to_index...), + Dict(vnv.varname_to_index...), map(identity, vnv.varnames), copy(vnv.ranges), map(identity, vnv.vals), @@ -1150,7 +1150,7 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) F = promote_type(F_left, F_right) # Allocate. - varname_to_index = OrderedDict{V,Int}() + varname_to_index = Dict{V,Int}() ranges = UnitRange{Int}[] vals = T[] transforms = F[] From 30ac1d0968da3cb171b5d0fdabf4893e15e00e29 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:13:47 +0000 Subject: [PATCH 02/11] Change concretisation from map(identity, x) to a comprehension --- src/debug_utils.jl | 2 +- src/logdensityfunction.jl | 4 ++-- src/test_utils/ad.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 13124e3a7..deb2a9256 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true) show_varname(io::IO, varname::VarName) = print(io, varname) function show_varname(io::IO, varname::Array{<:VarName,N}) where {N} # Attempt to make the type concrete in case the symbol is shared. - return _show_varname(io, map(identity, varname)) + return _show_varname(io, [vn for vn in varname]) end function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N} # Print the first and last element of the array. diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index e5e6a6dae..68281cfec 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -146,7 +146,7 @@ struct LogDensityFunction{ is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep - x = map(identity, varinfo[:]) + x = [val for val in varinfo[:]] if use_closure(adtype) prep = DI.prepare_gradient( LogDensityAt(model, getlogdensity, varinfo), adtype, x @@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient( ) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") - x = map(identity, x) # Concretise type + x = [val for val in x] # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 79442fccf..1cd83ec0a 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -260,7 +260,7 @@ function run_ad( if isnothing(params) params = varinfo[:] end - params = map(identity, params) # Concretise + params = [p for p in params] # Concretise # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" From 4ae0c6d4d064a4f77d7717248bf285547dcf4553 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:24:28 +0000 Subject: [PATCH 03/11] Improve tighten_types!! and loosen_types!! --- docs/src/api.md | 2 +- src/varnamedvector.jl | 157 ++++++++++++++++++++++++----------------- test/varnamedvector.jl | 6 +- 3 files changed, 98 insertions(+), 67 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 80970c0bb..469959947 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -414,7 +414,7 @@ DynamicPPL.reset! DynamicPPL.update! DynamicPPL.insert! DynamicPPL.loosen_types!! -DynamicPPL.tighten_types +DynamicPPL.tighten_types!! ``` ```@docs diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index e373a4e95..41c40ee72 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -116,24 +116,24 @@ like `setindex!` and `getindex!` rather than directly accessing `vnv.vals`. ```jldoctest varnamedvector-struct julia> vnv[@varname(x)] -2-element Vector{Real}: +2-element Vector{Float64}: 46.0 48.0 julia> getindex_internal(vnv, :) -8-element Vector{Real}: +8-element Vector{Float64}: 46.0 48.0 - 1 - 2 - 3 - 4 - 5 - 6 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 ``` """ struct VarNamedVector{ - K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector + K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T} } """ mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` @@ -143,7 +143,7 @@ struct VarNamedVector{ """ vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` """ - varnames::TVN # AbstractVector{<:VarName} + varnames::KVec """ vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has @@ -156,14 +156,14 @@ struct VarNamedVector{ vector of values of all variables; the value(s) of `vn` is/are `vals[ranges[varname_to_index[vn]]]` """ - vals::TVal # AbstractVector{<:Real} + vals::VVec """ vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable that transforms the value of `vn` back to its original space, undoing any linking and vectorisation """ - transforms::TTrans + transforms::TVec """ vector of booleans indicating whether a variable has been explicitly transformed to @@ -186,14 +186,14 @@ struct VarNamedVector{ function VarNamedVector( varname_to_index, - varnames::TVN, + varnames::KVec, ranges, - vals::TVal, - transforms::TTrans, + vals::VVec, + transforms::TVec, is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), num_inactive=Dict{Int,Int}(); check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, - ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} + ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} if check_consistency if length(varnames) != length(ranges) || length(varnames) != length(transforms) || @@ -257,7 +257,7 @@ struct VarNamedVector{ # tiny bit of thought. end - return new{K,V,TVN,TVal,TTrans}( + return new{K,V,T,KVec,VVec,TVec}( varname_to_index, varnames, ranges, @@ -269,9 +269,9 @@ struct VarNamedVector{ end end -function VarNamedVector{K,V}() where {K,V} +function VarNamedVector{K,V,T}() where {K,V,T} return VarNamedVector( - OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]; check_consistency=false + Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false ) end @@ -280,7 +280,7 @@ end # VarName and element types only as necessary, which would help keep them concrete. However, # making that change here opens some other cans of worms related to how VarInfo uses # BangBang, that I don't want to deal with right now. -VarNamedVector() = VarNamedVector{VarName,Real}() +VarNamedVector() = VarNamedVector{VarName,Real,An}() function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) end @@ -345,6 +345,12 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) vnv_left.num_inactive == vnv_right.num_inactive end +function is_concretely_typed(vnv::VarNamedVector) + return isconcretetype(eltype(vnv.varnames)) && + isconcretetype(eltype(vnv.vals)) && + isconcretetype(eltype(vnv.transforms)) +end + getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] @@ -562,7 +568,7 @@ to be the default vectorisation transform. This undoes any possible linking. ```jldoctest varnamedvector-reset julia> using DynamicPPL: VarNamedVector, @varname, reset! -julia> vnv = VarNamedVector(); +julia> vnv = VarNamedVector{VarName,Any,Any}(); julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); @@ -810,7 +816,7 @@ end # with every ! call replaced with a !! call. """ - loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew}) + loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. @@ -821,7 +827,7 @@ transformations of type `TransNew` can be pushed to it. Some of the underlying s shared between `vnv` and the return value, and thus mutating one may affect the other. # See also -[`tighten_types`](@ref) +[`tighten_types!!`](@ref) # Examples @@ -836,7 +842,9 @@ julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) ERROR: MethodError: Cannot `convert` an object of type [...] -julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), typeof(y_trans)); +julia> vnv_loose = DynamicPPL.loosen_types!!( + vnv, typeof(@varname(y)), Float64, typeof(y_trans) + ); julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) @@ -847,40 +855,57 @@ julia> vnv_loose[@varname(y)] ``` """ function loosen_types!!( - vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew} -) where {KNew,TransNew} + vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} +) where {KNew,VNew,TNew} K = eltype(vnv.varnames) - Trans = eltype(vnv.transforms) - if KNew <: K && TransNew <: Trans + V = eltype(vnv.vals) + T = eltype(vnv.transforms) + if KNew <: K && VNew <: V && TNew <: T return vnv else vn_type = promote_type(K, KNew) - transform_type = promote_type(Trans, TransNew) - return VarNamedVector( - Dict{vn_type,Int}(vnv.varname_to_index), - Vector{vn_type}(vnv.varnames), - vnv.ranges, - vnv.vals, - Vector{transform_type}(vnv.transforms), - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) + val_type = promote_type(V, VNew) + transform_type = promote_type(T, TNew) + # This function would work the same way if the first if statement a few lines above + # was skipped, and we only checked for the below condition. However, the first one + # is constant propagated away at compile time (at least on Julia v1.11.7), whereas + # this one isn't. Hence we keep both for performance. + return if vn_type == K && val_type == V && transform_type == T + vnv + elseif isempty(vnv) + VarNamedVector(vn_type[], val_type[], transform_type[]) + else + # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but + # then here always revert to Vector. + VarNamedVector( + Dict{vn_type,Int}(vnv.varname_to_index), + Vector{vn_type}(vnv.varnames), + vnv.ranges, + Vector{val_type}(vnv.vals), + Vector{transform_type}(vnv.transforms), + vnv.is_unconstrained, + vnv.num_inactive; + check_consistency=false, + ) + end end end """ - tighten_types(vnv::VarNamedVector) + tighten_types!!(vnv::VarNamedVector) -Return a copy of `vnv` with the most concrete types possible. +Return a `VarNamedVector` like `vnv` with the most concrete types possible. + +This function either returns `vnv` itself or new `VarNamedVector` with the same values in +it, but with the element types of various containers made as concrete as possible. For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the transforms are actually identity transformations, this function will return a new `VarNamedVector` with the transforms vector having eltype `typeof(identity)`. -This is a lot like the reverse of [`loosen_types!!`](@ref), but with two notable -differences: Unlike `loosen_types!!`, this function does not mutate `vnv`; it also changes -not only the key and transform eltypes, but also the values eltype. +This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the +return value may share some of its underlying storage with `vnv`, and thus mutating one may +affect the other. # See also [`loosen_types!!`](@ref) @@ -890,9 +915,9 @@ not only the key and transform eltypes, but also the values eltype. ```jldoctest varnamedvector-tighten-types julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! -julia> vnv = VarNamedVector(); +julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); -julia> setindex!(vnv, [23], @varname(x)) +julia> vnv = delete!(vnv, @varname(y)); julia> eltype(vnv) Real @@ -901,7 +926,7 @@ julia> vnv.transforms 1-element Vector{Any}: identity (generic function with 1 method) -julia> vnv_tight = DynamicPPL.tighten_types(vnv); +julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); julia> eltype(vnv_tight) == Int true @@ -911,17 +936,24 @@ julia> vnv_tight.transforms identity (generic function with 1 method) ``` """ -function tighten_types(vnv::VarNamedVector) - return VarNamedVector( - Dict(vnv.varname_to_index...), - map(identity, vnv.varnames), - copy(vnv.ranges), - map(identity, vnv.vals), - map(identity, vnv.transforms), - copy(vnv.is_unconstrained), - copy(vnv.num_inactive); - check_consistency=false, - ) +function tighten_types!!(vnv::VarNamedVector) + return if is_concretely_typed(vnv) + # There can not be anything to tighten, so short-circuit. + vnv + elseif isempty(vnv) + VarNamedVector() + else + VarNamedVector( + Dict(vnv.varname_to_index...), + [x for x in vnv.varnames], + vnv.ranges, + [x for x in vnv.vals], + [x for x in vnv.transforms], + vnv.is_unconstrained, + vnv.num_inactive; + check_consistency=false, + ) + end end function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) @@ -977,14 +1009,14 @@ function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=noth if transform === nothing transform = identity end - vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) + vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) insert_internal!(vnv, val, vn, transform) return vnv end function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform - vnv = loosen_types!!(vnv, typeof(vn), typeof(transform_resolved)) + vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) return vnv end @@ -1219,7 +1251,6 @@ julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) true """ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) - # NOTE: This does not specialize types when possible. vnv_new = similar(vnv) # Return early if possible. isempty(vnv) && return vnv_new @@ -1231,7 +1262,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) end end - return vnv_new + return tighten_types!!(vnv_new) end """ @@ -1430,7 +1461,7 @@ true """ function group_by_symbol(vnv::VarNamedVector) symbols = unique(map(getsym, vnv.varnames)) - nt_vals = map(s -> tighten_types(subset(vnv, [VarName{s}()])), symbols) + nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols) return OrderedDict(zip(symbols, nt_vals)) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 3fd76ffe2..59539a1a7 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -151,7 +151,7 @@ end @test eltype(vnv) == Real # Empty with types. - vnv = DynamicPPL.VarNamedVector{VarName,Float64}() + vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() @test isempty(vnv) @test eltype(vnv) == Float64 end @@ -369,13 +369,13 @@ end # Explicitly setting the transformation. increment(x) = x .+ 10 vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), typeof(increment)) + vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), eltype(vnv), typeof(increment)) DynamicPPL.setindex_internal!( vnv, to_vec_left(val_left .+ 100), vn_left, increment ) @test vnv[vn_left] == to_vec_left(val_left .+ 110) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), typeof(increment)) + vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), eltype(vnv), typeof(increment)) DynamicPPL.setindex_internal!( vnv, to_vec_right(val_right .+ 100), vn_right, increment ) From 4c8b006efa21b7013ef14021f863b4c5562538ed Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:26:44 +0000 Subject: [PATCH 04/11] Fix use of set_transformed!! --- src/contexts/init.jl | 4 +++- src/simple_varinfo.jl | 1 + src/varinfo.jl | 12 +++++++++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 396e1463f..44dbc5508 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -180,7 +180,9 @@ function tilde_assume!!( end # Neither of these set the `trans` flag so we have to do it manually if # necessary. - insert_transformed_value && set_transformed!!(vi, true, vn) + if insert_transformed_value + vi = set_transformed!!(vi, true, vn) + end # `accumulate_assume!!` wants untransformed values as the second argument. vi = accumulate_assume!!(vi, x, logjac, vn, dist) # We always return the untransformed value here, as that will determine diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2ba25f142..434480be6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", ) end + return vi end is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) diff --git a/src/varinfo.jl b/src/varinfo.jl index 734bf3db5..883c9e998 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -789,10 +789,16 @@ function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end +function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName) + md = set_transformed!!(getmetadata(vi, vn), val, vn) + return Accessors.@set vi.metadata[getsym(vn)] = md +end + function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) - set_transformed!!(getmetadata(vi, vn), val, vn) - return vi + md = set_transformed!!(getmetadata(vi, vn), val, vn) + return VarInfo(md, vi.accs) end + function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) metadata.is_transformed[getidx(metadata, vn)] = val return metadata @@ -800,7 +806,7 @@ end function set_transformed!!(vi::VarInfo, val::Bool) for vn in keys(vi) - set_transformed!!(vi, val, vn) + vi = set_transformed!!(vi, val, vn) end return vi From c8b0b882f94efa2a124025db025b173c92db1da3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:37:19 +0000 Subject: [PATCH 05/11] Fix push!! for VarInfos --- src/varinfo.jl | 52 ++++++++++++++++++++++++++++++------------- src/varnamedvector.jl | 7 +++++- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 883c9e998..6cbe27b6a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1658,30 +1658,45 @@ end Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to the `VarInfo` `vi`, mutating if it makes sense. """ -function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa NTVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" - end +function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" + md = push!!(getmetadata(vi, vn), vn, val, dist) + return VarInfo(md, vi.accs) +end +function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" sym = getsym(vn) - if vi isa NTVarInfo && ~haskey(vi.metadata, sym) + meta = if ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. - val = tovec(r) - md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) - vi = Accessors.@set vi.metadata[sym] = md + _new_submetadata(vi, vn, val, dist) else - meta = getmetadata(vi, vn) - push!(meta, vn, r, dist) + push!!(getmetadata(vi, vn), vn, val, dist) end - + vi = Accessors.@set vi.metadata[sym] = meta return vi end -function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) - push!(getmetadata(vi, vn), vn, val, args...) - return vi +""" + _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} + +Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing +SubMetas. +""" +@generated function _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist) where {Names,SubMetas} + has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) + return if has_vnv + :(return _new_vnv_submetadata(vn, r, dist)) + else + :(return _new_metadata_submetadata(vn, r, dist)) + end +end + +_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) + +function _new_metadata_submetadata(vn, r, dist) + val = tovec(r) + return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) end function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) @@ -1706,6 +1721,11 @@ function Base.push!(meta::Metadata, vn, r, dist) return meta end +function BangBang.push!!(meta::Metadata, vn, r, dist) + push!(meta, vn, r, dist) + return meta +end + function Base.delete!(vi::VarInfo, vn::VarName) delete!(getmetadata(vi, vn), vn) return vi diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 41c40ee72..fbcfb58ce 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -803,11 +803,16 @@ function update_internal!( return nothing end -function BangBang.push!(vnv::VarNamedVector, vn, val, dist) +function Base.push!(vnv::VarNamedVector, vn, val, dist) f = from_vec_transform(dist) return setindex_internal!(vnv, tovec(val), vn, f) end +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new From 1f7152b2b17859cd5d1e01a07862975e498bfc1a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:39:30 +0000 Subject: [PATCH 06/11] Change the default element types in VNV to be Union{} --- src/varnamedvector.jl | 66 +++++++++++++++++++++++------------------- test/varnamedvector.jl | 2 +- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index fbcfb58ce..3cc0a469e 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -56,26 +56,26 @@ $(FIELDS) The values for different variables are internally all stored in a single vector. For instance, ```jldoctest varnamedvector-struct -julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal +julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal julia> vnv = VarNamedVector(); -julia> setindex!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); +julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); -julia> setindex!(vnv, reshape(1:6, (2,3)), @varname(y)); +julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y)); julia> vnv.vals -10-element Vector{Real}: +10-element Vector{Float64}: 0.0 0.0 0.0 0.0 - 1 - 2 - 3 - 4 - 5 - 6 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 ``` The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to @@ -91,20 +91,20 @@ If a variable is updated with a new value that is of a smaller dimension than th value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive. ```jldoctest varnamedvector-struct -julia> update!(vnv, [46.0, 48.0], @varname(x)) +julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x)); julia> vnv.vals -10-element Vector{Real}: +10-element Vector{Float64}: 46.0 48.0 0.0 0.0 - 1 - 2 - 3 - 4 - 5 - 6 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 julia> println(vnv.num_inactive); Dict(1 => 2) @@ -275,12 +275,7 @@ function VarNamedVector{K,V,T}() where {K,V,T} ) end -# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Simlarly the -# transform vector type above could then be Union{}[]. This would allow expanding the -# VarName and element types only as necessary, which would help keep them concrete. However, -# making that change here opens some other cans of worms related to how VarInfo uses -# BangBang, that I don't want to deal with right now. -VarNamedVector() = VarNamedVector{VarName,Real,An}() +VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}() function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) end @@ -298,15 +293,16 @@ function VarNamedVector( transforms=fill(identity, length(varnames)); check_consistency=CHECK_CONSISTENCY_DEFAULT, ) + if isempty(varnames) && isempty(orig_vals) && isempty(transforms) + return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}() + end # Convert `vals` into a vector of vectors. vals_vecs = map(tovec, orig_vals) transforms = map( (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals ) - # TODO: Is this really the way to do this? - if !(eltype(varnames) <: VarName) - varnames = convert(Vector{VarName}, varnames) - end + # Make `varnames` have as concrete an element type as possible. + varnames = [v for v in varnames] varname_to_index = Dict{eltype(varnames),Int}( vn => i for (i, vn) in enumerate(varnames) ) @@ -1010,7 +1006,7 @@ function setindex_internal!!( end end -function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) +function insert_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing) if transform === nothing transform = identity end @@ -1019,7 +1015,7 @@ function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=noth return vnv end -function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) +function update_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing) transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) @@ -1544,6 +1540,16 @@ function Base.delete!(vnv::VarNamedVector, vn::VarName) return vnv end +""" + delete!!(vnv::VarNamedVector, vn::VarName) + +Like `delete!!`, but tightens the element types of the returned `VarNamedVector`. + +# See also: +[`tighten_types!!`](@ref) +""" +BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn)) + """ values_as(vnv::VarNamedVector[, T]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 59539a1a7..8750fbf59 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -148,7 +148,7 @@ end # Empty. vnv = DynamicPPL.VarNamedVector() @test isempty(vnv) - @test eltype(vnv) == Real + @test eltype(vnv) == Union{} # Empty with types. vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() From 61c96b0803776687d0539454c50dd08789de1112 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:40:32 +0000 Subject: [PATCH 07/11] In untyped_vector_varinfo, don't rely on Metadata --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 6cbe27b6a..dc216e21d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -315,7 +315,7 @@ function untyped_vector_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) + return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy)) end function untyped_vector_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() From 2a10be9e1c609312549b5cec4cfc9fbdc54b827b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Oct 2025 19:40:56 +0000 Subject: [PATCH 08/11] Code style --- src/debug_utils.jl | 2 +- src/logdensityfunction.jl | 2 +- src/test_utils/ad.jl | 2 +- src/varinfo.jl | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index deb2a9256..e8b50a0b7 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -407,7 +407,7 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually +julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually # alert us to the issue of `x` being sampled twice. model = demo_incorrect(); varinfo = VarInfo(model); diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 68281cfec..21f6817bd 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ box: - [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of linking - [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) + by linking, since transforms are only applied to random variables) !!! note By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 1cd83ec0a..a49ffd18b 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups: 1. _How to specify the results to compare against._ Once logp and its gradient has been calculated with the specified `adtype`, - it can optionally be tested for correctness. The exact way this is tested + it can optionally be tested for correctness. The exact way this is tested is specified in the `test` parameter. There are several options for this: diff --git a/src/varinfo.jl b/src/varinfo.jl index dc216e21d..2ba852cb6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -983,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns) end @generated function _link!!( - ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} + ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) for f in metadata_names @@ -994,7 +994,7 @@ end expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns.$f, f_vns) + f_vns = filter_subsumed(varnames.$f, f_vns) if !isempty(f_vns) if !is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform From bb83d936c8f338dfe0c54461cd5ad9e71320b5e1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Oct 2025 09:21:43 +0000 Subject: [PATCH 09/11] Run formatter --- src/logdensityfunction.jl | 2 +- src/varinfo.jl | 4 +++- src/varnamedvector.jl | 8 ++++++-- test/varnamedvector.jl | 8 ++++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 21f6817bd..3b7b84953 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient( ) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type + x = [val for val in x] # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) diff --git a/src/varinfo.jl b/src/varinfo.jl index 2ba852cb6..a90b81488 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1683,7 +1683,9 @@ end Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing SubMetas. """ -@generated function _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist) where {Names,SubMetas} +@generated function _new_submetadata( + vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist +) where {Names,SubMetas} has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) return if has_vnv :(return _new_vnv_submetadata(vn, r, dist)) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 3cc0a469e..1c6c5116d 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1006,7 +1006,9 @@ function setindex_internal!!( end end -function insert_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing) +function insert_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) if transform === nothing transform = identity end @@ -1015,7 +1017,9 @@ function insert_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName return vnv end -function update_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing) +function update_internal!!( + vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing +) transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 8750fbf59..b764d517b 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -369,13 +369,17 @@ end # Explicitly setting the transformation. increment(x) = x .+ 10 vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_left), eltype(vnv), typeof(increment)) + vnv = DynamicPPL.loosen_types!!( + vnv, typeof(vn_left), eltype(vnv), typeof(increment) + ) DynamicPPL.setindex_internal!( vnv, to_vec_left(val_left .+ 100), vn_left, increment ) @test vnv[vn_left] == to_vec_left(val_left .+ 110) - vnv = DynamicPPL.loosen_types!!(vnv, typeof(vn_right), eltype(vnv), typeof(increment)) + vnv = DynamicPPL.loosen_types!!( + vnv, typeof(vn_right), eltype(vnv), typeof(increment) + ) DynamicPPL.setindex_internal!( vnv, to_vec_right(val_right .+ 100), vn_right, increment ) From 99a7d325f2e96a14944aef5361be3907b00b9073 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Oct 2025 14:19:29 +0000 Subject: [PATCH 10/11] In VNV, use typejoin rather than promote_type --- src/varnamedvector.jl | 62 ++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 1c6c5116d..2c66e1245 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -65,17 +65,17 @@ julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y)); julia> vnv.vals -10-element Vector{Float64}: +10-element Vector{Real}: 0.0 0.0 0.0 0.0 - 1.0 - 2.0 - 3.0 - 4.0 - 5.0 - 6.0 + 1 + 2 + 3 + 4 + 5 + 6 ``` The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to @@ -94,17 +94,17 @@ value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked a julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x)); julia> vnv.vals -10-element Vector{Float64}: +10-element Vector{Real}: 46.0 48.0 0.0 0.0 - 1.0 - 2.0 - 3.0 - 4.0 - 5.0 - 6.0 + 1 + 2 + 3 + 4 + 5 + 6 julia> println(vnv.num_inactive); Dict(1 => 2) @@ -116,20 +116,20 @@ like `setindex!` and `getindex!` rather than directly accessing `vnv.vals`. ```jldoctest varnamedvector-struct julia> vnv[@varname(x)] -2-element Vector{Float64}: +2-element Vector{Real}: 46.0 48.0 julia> getindex_internal(vnv, :) -8-element Vector{Float64}: +8-element Vector{Real}: 46.0 48.0 - 1.0 - 2.0 - 3.0 - 4.0 - 5.0 - 6.0 + 1 + 2 + 3 + 4 + 5 + 6 ``` """ struct VarNamedVector{ @@ -864,9 +864,15 @@ function loosen_types!!( if KNew <: K && VNew <: V && TNew <: T return vnv else - vn_type = promote_type(K, KNew) - val_type = promote_type(V, VNew) - transform_type = promote_type(T, TNew) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing + # and Missing, rather than falling back on Any. However, it's not exported. + vn_type = typejoin(K, KNew) + val_type = typejoin(V, VNew) + transform_type = typejoin(T, TNew) # This function would work the same way if the first if statement a few lines above # was skipped, and we only checked for the below condition. However, the first one # is constant propagated away at compile time (at least on Julia v1.11.7), whereas @@ -1171,12 +1177,12 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Determine `eltype` of `vals`. T_left = eltype(left_vnv.vals) T_right = eltype(right_vnv.vals) - T = promote_type(T_left, T_right) + T = typejoin(T_left, T_right) # Determine `eltype` of `varnames`. V_left = eltype(left_vnv.varnames) V_right = eltype(right_vnv.varnames) - V = promote_type(V_left, V_right) + V = typejoin(V_left, V_right) if !(V <: VarName) V = VarName end @@ -1184,7 +1190,7 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Determine `eltype` of `transforms`. F_left = eltype(left_vnv.transforms) F_right = eltype(right_vnv.transforms) - F = promote_type(F_left, F_right) + F = typejoin(F_left, F_right) # Allocate. varname_to_index = Dict{V,Int}() From d30eca80fb23e45274af6c1600c8974a7d121a02 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Oct 2025 14:33:36 +0000 Subject: [PATCH 11/11] Bump patch version to 0.38.4 --- HISTORY.md | 4 ++++ Project.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 54b40b7e9..4b8f5980e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.38.4 + +Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on. + ## 0.38.3 Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`. diff --git a/Project.toml b/Project.toml index d54f9d1da..0773bbe04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.3" +version = "0.38.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"