2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -194,14 +194,14 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("experimental.jl")
include("chains.jl")
include("bijector.jl")

include("debug_utils.jl")
using .DebugUtils
include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")

if isdefined(Base.Experimental, :register_error_hint)
36 changes: 36 additions & 0 deletions src/chains.jl
@@ -90,6 +90,42 @@ function ParamsWithStats(
return ParamsWithStats(params, stats)
end

function ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)
ctx = DynamicPPL.Experimental.FastEvalVectorContext(
ldf._iden_varname_ranges, ldf._varname_ranges, param_vector
)
accs = if include_log_probs
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
)
else
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
end
_, varinfo = DynamicPPL.Experimental.fast_evaluate!!(
ldf.model, ctx, AccumulatorTuple(accs)
)
params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
if include_log_probs
stats = merge(
stats,
(
logprior=DynamicPPL.getlogprior(varinfo),
loglikelihood=DynamicPPL.getloglikelihood(varinfo),
lp=DynamicPPL.getlogjoint(varinfo),
),
)
end
return ParamsWithStats(params, stats)
end

# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to
# convert it to a typed varinfo first, hence this method.
# https://github.com/TuringLang/Turing.jl/issues/2604
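The new constructor above re-evaluates the model once with a `ValuesAsInModelAccumulator` (plus log-prior and log-likelihood accumulators when `include_log_probs=true`) to turn a flat parameter vector back into per-variable values and summary statistics. A minimal usage sketch, assuming a `FastLDF` built from a model and a `VarInfo` as in the new tests further down; the `demo` model and its data are hypothetical:

```julia
using DynamicPPL, Distributions

@model function demo(y)
    μ ~ Normal()
    y ~ Normal(μ, 1)
end

model = demo(0.5)
vi = VarInfo(model)

# FastLDF built from the model, a log-density getter, and a VarInfo
fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)

# Flat parameter vector in the same ordering (and linking) as `vi`
params = vi[:]

ps = ParamsWithStats(params, fldf)
ps.params[@varname(μ)]  # value of μ as seen inside the model
```

With the default `include_log_probs=true`, the single evaluation also records `logprior`, `loglikelihood`, and `lp` in the resulting statistics.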
30 changes: 18 additions & 12 deletions src/fasteval.jl
@@ -352,16 +352,9 @@ end
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
_model::M
_getlogdensity::F
_iden_varname_ranges::N
_varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
model = DynamicPPL.setleafcontext(f._model, ctx)
accs = fast_ldf_accs(f._getlogdensity)
function fast_evaluate!!(model::Model, ctx::FastEvalVectorContext, accs::AccumulatorTuple)
model = DynamicPPL.setleafcontext(model, ctx)
vi = OnlyAccsVarInfo(accs)
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
@@ -371,14 +364,27 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
accs = map(
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
acc -> DynamicPPL.convert_eltype(
float_type_with_fallback(eltype(ctx.params)), acc
),
accs,
)
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
else
OnlyAccsVarInfo(accs)
end
_, vi = DynamicPPL._evaluate!!(model, vi)
return DynamicPPL._evaluate!!(model, vi)
end

struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
_model::M
_getlogdensity::F
_iden_varname_ranges::N
_varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
_, vi = fast_evaluate!!(f._model, ctx, fast_ldf_accs(f._getlogdensity))
return f._getlogdensity(vi)
end

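The refactor above extracts the evaluation core of `FastLogDensityAt` into a standalone `fast_evaluate!!(model, ctx, accs)`, so the new `ParamsWithStats` constructor in `src/chains.jl` can share the same code path. A sketch of calling it directly with a custom accumulator tuple, assuming the `fldf` and `params` from the previous example (the context and accumulator types are internal, so this is illustrative only):

```julia
ctx = DynamicPPL.Experimental.FastEvalVectorContext(
    fldf._iden_varname_ranges, fldf._varname_ranges, params
)
# Accumulate only the log-prior on this evaluation
accs = DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator(),))
_, vi = DynamicPPL.Experimental.fast_evaluate!!(fldf.model, ctx, accs)
DynamicPPL.getlogprior(vi)
```

As in `FastLogDensityAt`, the helper wraps the accumulators in a `ThreadSafeVarInfo` when `Threads.nthreads() > 1`, converting their element types to match the parameter eltype.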
29 changes: 28 additions & 1 deletion test/chains.jl
@@ -4,7 +4,7 @@ using DynamicPPL
using Distributions
using Test

@testset "ParamsWithStats" begin
@testset "ParamsWithStats, from VarInfo" begin
@model function f(z)
x ~ Normal()
y := x + 1
@@ -66,4 +66,31 @@ using Test
end
end

@testset "ParamsWithStats from FastLDF" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
unlinked_vi = VarInfo(m)
@testset "$islinked" for islinked in (false, true)
vi = if islinked
DynamicPPL.link!!(unlinked_vi, m)
else
unlinked_vi
end
nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi)
params = map(identity, vi[:])

# Get the ParamsWithStats using FastLDF
fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
ps = ParamsWithStats(params, fldf)

# Check that length of parameters is as expected
@test length(ps.params) == length(keys(vi))

# Iterate over all variables to check that their values match
for vn in keys(vi)
@test ps.params[vn] == vi[vn]
end
end
end
end

end # module