From 66423bf633dae5b434c39b6aa47f3b4ac574c819 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 17:49:10 +0000 Subject: [PATCH 1/3] Implement `ParamsWithStats` from `FastLDF` --- src/chains.jl | 36 ++++++++++++++++++++++++++++++++++++ src/fasteval.jl | 26 +++++++++++++++----------- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..01d2b6f0f 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -90,6 +90,42 @@ function ParamsWithStats( return ParamsWithStats(params, stats) end +function ParamsWithStats( + param_vector::AbstractVector, + ldf::DynamicPPL.Experimental.FastLDF, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) + ctx = DynamicPPL.Experimental.FastEvalVectorContext( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + return _, varinfo = DynamicPPL.Experimental.fast_evaluate!!( + ldf.model, ctx, AccumulatorTuple(accs) + ) + params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(varinfo), + loglikelihood=DynamicPPL.getloglikelihood(varinfo), + lp=DynamicPPL.getlogjoint(varinfo), + ), + ) + end + return ParamsWithStats(params, stats) +end + # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to # convert it to a typed varinfo first, hence this method. # https://github.com/TuringLang/Turing.jl/issues/2604 diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..52abfd5ba 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -352,16 +352,9 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} - _model::M - _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} -end -function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) - model = DynamicPPL.setleafcontext(f._model, ctx) - accs = fast_ldf_accs(f._getlogdensity) +function fast_evaluate!!(model::Model, ctx::FastEvalVectorContext, accs::AccumulatorTuple) + model = DynamicPPL.setleafcontext(model, ctx) + vi = OnlyAccsVarInfo(accs) # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic @@ -378,7 +371,18 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) else OnlyAccsVarInfo(accs) end - _, vi = DynamicPPL._evaluate!!(model, vi) + return DynamicPPL._evaluate!!(model, vi) +end + +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} + _model::M + _getlogdensity::F + _iden_varname_ranges::N + _varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) + _, vi = fast_evaluate!!(f._model, ctx, fast_ldf_accs(f._getlogdensity)) return f._getlogdensity(vi) end From f90d4511d5a6ac363955787f675019e8be65d78e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 17:55:57 +0000 Subject: [PATCH 2/3] Add tests, fix import order --- src/DynamicPPL.jl | 2 +- src/chains.jl | 2 +- test/chains.jl | 29 ++++++++++++++++++++++++++++- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..c13746326 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -194,6 +194,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -201,7 +202,6 @@ include("debug_utils.jl") using .DebugUtils include("test_utils.jl") -include("experimental.jl") include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) diff --git a/src/chains.jl b/src/chains.jl index 01d2b6f0f..7dbc23d72 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -109,7 +109,7 @@ function ParamsWithStats( else (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) end - return _, varinfo = DynamicPPL.Experimental.fast_evaluate!!( + _, varinfo = DynamicPPL.Experimental.fast_evaluate!!( ldf.model, ctx, AccumulatorTuple(accs) ) params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values diff --git a/test/chains.jl b/test/chains.jl index ab0ff4475..58740f1f5 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -4,7 +4,7 @@ using DynamicPPL using Distributions using Test -@testset "ParamsWithStats" begin +@testset "ParamsWithStats, from VarInfo" begin @model function f(z) x ~ Normal() y := x + 1 @@ -66,4 +66,31 @@ using Test end end +@testset "ParamsWithStats from FastLDF" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = map(identity, vi[:]) + + # Get the ParamsWithStats using FastLDF + fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi) + ps = ParamsWithStats(params, fldf) + + # Check that length of parameters is as expected + @test length(ps.params) == length(keys(vi)) + + # Iterate over all variables to check that their values match + for vn in keys(vi) + @test ps.params[vn] == vi[vn] + end + end + end +end + end # module From c475f9617f8f79d359fa0bf984b4413e067e5122 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 22:59:04 +0000 Subject: [PATCH 3/3] Fix bug --- src/fasteval.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 52abfd5ba..7b8b1aa5d 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -364,7 +364,9 @@ function fast_evaluate!!(model::Model, ctx::FastEvalVectorContext, accs::Accumul # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 vi = if Threads.nthreads() > 1 accs = map( - acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), + acc -> DynamicPPL.convert_eltype( + float_type_with_fallback(eltype(ctx.params)), acc + ), accs, ) ThreadSafeVarInfo(OnlyAccsVarInfo(accs))