From 4ca9cf74fe1fd209a83fa1e4e99188b7e74e9a98 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Nov 2025 13:05:35 +0000 Subject: [PATCH 1/5] Squeeze down VarInfo allocations --- src/varinfo.jl | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..f149c3522 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -394,7 +394,9 @@ end for f in names mdf = :(metadata.$f) len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + push!( + exprs, :($f = unflatten_metadata($mdf, @view x[($offset + 1):($offset + $len)])) + ) offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) @@ -755,7 +757,10 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, # since then we might be returning a `SubArray` rather than an `Array`, which is typically # what a bijector would result in, even if the input is a view (`SubArray`). # TODO(torfjelde): An alternative is to implement `view` directly instead. -getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) +function getindex_internal(md::Metadata, vn::VarName) + rng = getrange(md, vn) + return @view md.vals[rng] +end function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) end @@ -1495,8 +1500,21 @@ space. If some but only some of the variables in `vi` are transformed, this function will return `true`. This behavior will likely change in the future. """ -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) +function is_transformed(vi::NTVarInfo) + return is_transformed(vi.metadata) +end + +@generated function is_transformed(nt::NamedTuple{names}) where {names} + expr = Expr(:block) + push!(expr.args, :(result = false)) + for n in names + push!(expr.args, :(result = result || is_transformed(nt.$n))) + end + return expr +end + +function is_transformed(md::Metadata) + return any(md.is_transformed) end # The default getindex & setindex!() for get & set values @@ -1552,7 +1570,7 @@ end @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) + push!(expr.args, :(@view metadata.$f.vals[ranges.$f])) end return expr end From 7c6e8c1b692663357cd339bcb96a60cbfd65adfc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Nov 2025 13:08:32 +0000 Subject: [PATCH 2/5] Remove old out-of-date comment --- src/varinfo.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f149c3522..0d77c70e6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -753,10 +753,6 @@ function getdist(::VarNamedVector, ::VarName) end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) -# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, -# since then we might be returning a `SubArray` rather than an `Array`, which is typically -# what a bijector would result in, even if the input is a view (`SubArray`). -# TODO(torfjelde): An alternative is to implement `view` directly instead. function getindex_internal(md::Metadata, vn::VarName) rng = getrange(md, vn) return @view md.vals[rng] From 5c817a46ad0b040a1d39ee71d77861ba5b13d86f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:25:16 +0000 Subject: [PATCH 3/5] implement `is_transformed(::VarNamedVector)` --- src/varnamedvector.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..a81f33ea5 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -367,6 +367,12 @@ Return a boolean for whether `vn` is guaranteed to have been transformed so that is all of Euclidean space. """ is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] +""" + is_transformed(vnv::VarNamedVector) + +Return true if any variable in `vnv` is guaranteed to have been transformed. +""" +is_transformed(vnv::VarNamedVector) = any(vnv.is_unconstrained) """ set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) From 93daa2b3286d1cc68e1c176f9f34b001e3561a95 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:36:15 +0000 Subject: [PATCH 4/5] Handle errors in benchmark suite --- benchmarks/benchmarks.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..cf5e7daa6 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -87,12 +87,18 @@ results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations @info "Running benchmark for $model_name" - suite = make_suite(model, varinfo_choice, adbackend, islinked) - results = run(suite) - eval_time = median(results["evaluation"]).time - relative_eval_time = eval_time / reference_time - ad_eval_time = median(results["gradient"]).time - relative_ad_eval_time = ad_eval_time / eval_time + try + suite = make_suite(model, varinfo_choice, adbackend, islinked) + results = run(suite) + eval_time = median(results["evaluation"]).time + relative_eval_time = eval_time / reference_time + ad_eval_time = median(results["gradient"]).time + relative_ad_eval_time = ad_eval_time / eval_time + catch e + @warn "Benchmark failed for $model_name with error: $e" + relative_eval_time = NaN + relative_ad_eval_time = NaN + end push!( results_table, ( From 6e633d2b400f0178fc4391426924eb6f26787366 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 03:09:00 +0000 Subject: [PATCH 5/5] actually fix the benchmarks because i want to see them --- benchmarks/benchmarks.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index cf5e7daa6..3cb088273 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -87,17 +87,15 @@ results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations @info "Running benchmark for $model_name" - try + relative_eval_time, relative_ad_eval_time = try suite = make_suite(model, varinfo_choice, adbackend, islinked) results = run(suite) eval_time = median(results["evaluation"]).time - relative_eval_time = eval_time / reference_time ad_eval_time = median(results["gradient"]).time - relative_ad_eval_time = ad_eval_time / eval_time + (eval_time / reference_time), (ad_eval_time / eval_time) catch e @warn "Benchmark failed for $model_name with error: $e" - relative_eval_time = NaN - relative_ad_eval_time = NaN + NaN, NaN end push!( results_table,