Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,16 @@ results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]

for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
@info "Running benchmark for $model_name"
suite = make_suite(model, varinfo_choice, adbackend, islinked)
results = run(suite)
eval_time = median(results["evaluation"]).time
relative_eval_time = eval_time / reference_time
ad_eval_time = median(results["gradient"]).time
relative_ad_eval_time = ad_eval_time / eval_time
relative_eval_time, relative_ad_eval_time = try
suite = make_suite(model, varinfo_choice, adbackend, islinked)
results = run(suite)
eval_time = median(results["evaluation"]).time
ad_eval_time = median(results["gradient"]).time
(eval_time / reference_time), (ad_eval_time / eval_time)
catch e
@warn "Benchmark failed for $model_name with error: $e"
NaN, NaN
end
push!(
results_table,
(
Expand Down
32 changes: 23 additions & 9 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,9 @@ end
for f in names
mdf = :(metadata.$f)
len = :(sum(length, $mdf.ranges))
push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)])))
push!(
exprs, :($f = unflatten_metadata($mdf, @view x[($offset + 1):($offset + $len)]))
)
offset = :($offset + $len)
end
length(exprs) == 0 && return :(NamedTuple())
Expand Down Expand Up @@ -751,11 +753,10 @@ function getdist(::VarNamedVector, ::VarName)
end

getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn)
# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though,
# since then we might be returning a `SubArray` rather than an `Array`, which is typically
# what a bijector would result in, even if the input is a view (`SubArray`).
# TODO(torfjelde): An alternative is to implement `view` directly instead.
getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn))
function getindex_internal(md::Metadata, vn::VarName)
rng = getrange(md, vn)
return @view md.vals[rng]
end
function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
end
Expand Down Expand Up @@ -1495,8 +1496,21 @@ space.
If some but only some of the variables in `vi` are transformed, this function will return
`true`. This behavior will likely change in the future.
"""
function is_transformed(vi::VarInfo)
return any(is_transformed(vi, vn) for vn in keys(vi))
function is_transformed(vi::NTVarInfo)
return is_transformed(vi.metadata)
end

@generated function is_transformed(nt::NamedTuple{names}) where {names}
expr = Expr(:block)
push!(expr.args, :(result = false))
for n in names
push!(expr.args, :(result = result || is_transformed(nt.$n)))
end
return expr
end

function is_transformed(md::Metadata)
return any(md.is_transformed)
end

# The default getindex & setindex!() for get & set values
Expand Down Expand Up @@ -1552,7 +1566,7 @@ end
@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names}
expr = Expr(:tuple)
for f in names
push!(expr.args, :(metadata.$f.vals[ranges.$f]))
push!(expr.args, :(@view metadata.$f.vals[ranges.$f]))
end
return expr
end
Expand Down
6 changes: 6 additions & 0 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ Return a boolean for whether `vn` is guaranteed to have been transformed so that
is all of Euclidean space.
"""
is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)]
"""
is_transformed(vnv::VarNamedVector)

Return true if any variable in `vnv` is guaranteed to have been transformed.
"""
is_transformed(vnv::VarNamedVector) = any(vnv.is_unconstrained)

"""
set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName)
Expand Down
Loading