diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..3cb088273 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -87,12 +87,16 @@ results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations @info "Running benchmark for $model_name" - suite = make_suite(model, varinfo_choice, adbackend, islinked) - results = run(suite) - eval_time = median(results["evaluation"]).time - relative_eval_time = eval_time / reference_time - ad_eval_time = median(results["gradient"]).time - relative_ad_eval_time = ad_eval_time / eval_time + relative_eval_time, relative_ad_eval_time = try + suite = make_suite(model, varinfo_choice, adbackend, islinked) + results = run(suite) + eval_time = median(results["evaluation"]).time + ad_eval_time = median(results["gradient"]).time + (eval_time / reference_time), (ad_eval_time / eval_time) + catch e + @warn "Benchmark failed for $model_name with error: $e" + NaN, NaN + end push!( results_table, ( diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..0d77c70e6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -394,7 +394,9 @@ end for f in names mdf = :(metadata.$f) len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + push!( + exprs, :($f = unflatten_metadata($mdf, @view x[($offset + 1):($offset + $len)])) + ) offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) @@ -751,11 +753,10 @@ function getdist(::VarNamedVector, ::VarName) end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) -# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, -# since then we might be returning a `SubArray` rather than an `Array`, which is typically -# what a bijector would result in, even if the input is a view (`SubArray`). -# TODO(torfjelde): An alternative is to implement `view` directly instead. -getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) +function getindex_internal(md::Metadata, vn::VarName) + rng = getrange(md, vn) + return @view md.vals[rng] +end function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) end @@ -1495,8 +1496,21 @@ space. If some but only some of the variables in `vi` are transformed, this function will return `true`. This behavior will likely change in the future. """ -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) +function is_transformed(vi::NTVarInfo) + return is_transformed(vi.metadata) +end + +@generated function is_transformed(nt::NamedTuple{names}) where {names} + expr = Expr(:block) + push!(expr.args, :(result = false)) + for n in names + push!(expr.args, :(result = result || is_transformed(nt.$n))) + end + return expr +end + +function is_transformed(md::Metadata) + return any(md.is_transformed) end # The default getindex & setindex!() for get & set values @@ -1552,7 +1566,7 @@ end @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) + push!(expr.args, :(@view metadata.$f.vals[ranges.$f])) end return expr end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..a81f33ea5 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -367,6 +367,12 @@ Return a boolean for whether `vn` is guaranteed to have been transformed so that is all of Euclidean space. """ is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] +""" + is_transformed(vnv::VarNamedVector) + +Return true if any variable in `vnv` is guaranteed to have been transformed. +""" +is_transformed(vnv::VarNamedVector) = any(vnv.is_unconstrained) """ set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName)