From 7cddac775a3bfa0eb0808bb241551e875c6ac0b2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 5 Nov 2025 23:58:25 +0000 Subject: [PATCH 01/30] Fast Log Density Function --- src/DynamicPPL.jl | 1 + src/fastldf.jl | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 src/fastldf.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..77d527ced 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -191,6 +191,7 @@ include("simple_varinfo.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") +include("fastldf.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/fastldf.jl b/src/fastldf.jl new file mode 100644 index 000000000..05b024dbe --- /dev/null +++ b/src/fastldf.jl @@ -0,0 +1,91 @@ +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext + varname_ranges::Dict{VarName,RangeAndLinked} + params::T +end +DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() + +function tilde_assume!!( + ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo +) + # Don't need to read the data from the varinfo at all since it's + # all inside the context. + range_and_linked = ctx.varname_ranges[vn] + y = @view ctx.params[range_and_linked.range] + is_linked = range_and_linked.is_linked + f = if is_linked + from_linked_vec_transform(right) + else + from_vec_transform(right) + end + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +function tilde_observe!!( + ::FastLDFContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::OnlyAccsVarInfo, +) + # This is the same as for DefaultContext + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi +end + +struct FastLDF{M<:Model,F<:Function} + _model::M + _getlogdensity::F + _varname_ranges::Dict{VarName,RangeAndLinked} + + function FastLDF( + model::Model, + getlogdensity::Function, + # This only works with typed Metadata-varinfo. + # Obviously, this can be generalised later. + varinfo::VarInfo{<:NamedTuple{syms}}, + ) where {syms} + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + all_ranges[vn] = RangeAndLinked(range, is_linked) + offset += len + end + end + return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) + end +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + ctx = FastLDFContext(fldf._varname_ranges, params) + model = DynamicPPL.setleafcontext(fldf._model, ctx) + # This can obviously also be optimised for the case where not + # all accumulators are needed. + accs = AccumulatorTuple(( + LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() + )) + _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + return fldf._getlogdensity(vi) +end From 5ed4295f23b0c5cd2d1568de247da249de42913b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 00:33:49 +0000 Subject: [PATCH 02/30] Make it work with AD --- src/fastldf.jl | 56 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 05b024dbe..2178c84a2 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -48,17 +48,25 @@ function tilde_observe!!( return left, vi end -struct FastLDF{M<:Model,F<:Function} +struct FastLDF{ + M<:Model, + F<:Function, + AD<:Union{ADTypes.AbstractADType,Nothing}, + ADP<:Union{Nothing,DI.GradientPrep}, +} _model::M _getlogdensity::F _varname_ranges::Dict{VarName,RangeAndLinked} + _adtype::AD + _adprep::ADP function FastLDF( model::Model, getlogdensity::Function, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}, + varinfo::VarInfo{<:NamedTuple{syms}}; + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) where {syms} # Figure out which variable corresponds to which index, and # which variables are linked. @@ -74,18 +82,52 @@ struct FastLDF{M<:Model,F<:Function} offset += len end end - return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = tweak_adtype(adtype, model, varinfo) + x = [val for val in varinfo[:]] + DI.prepare_gradient( + FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x + ) + end + + return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}( + model, getlogdensity, all_ranges, adtype, prep + ) end end -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - ctx = FastLDFContext(fldf._varname_ranges, params) - model = DynamicPPL.setleafcontext(fldf._model, ctx) +struct FastLogDensityAt{M<:Model,F<:Function} + _model::M + _getlogdensity::F + _varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = FastLDFContext(f._varname_ranges, params) + model = DynamicPPL.setleafcontext(f._model, ctx) # This can obviously also be optimised for the case where not # all accumulators are needed. accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) - return fldf._getlogdensity(vi) + return f._getlogdensity(vi) +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params) +end + +function LogDensityProblems.logdensity_and_gradient( + fldf::FastLDF, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges), + fldf._adprep, + fldf._adtype, + params, + ) end From e199520dac0b8c0ae82a3ed7fd1e4673bbd7c28a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 01:37:05 +0000 Subject: [PATCH 03/30] Optimise performance for identity VarNames --- src/fastldf.jl | 69 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 2178c84a2..e59b12791 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -12,21 +12,35 @@ struct RangeAndLinked is_linked::Bool end -struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext +struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext + # The ranges of identity VarNames are stored in a NamedTuple for performance + # reasons. For just plain evaluation this doesn't make _that_ much of a + # difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE + # difference (around 4x). Of course, the exact numbers depend on the model. + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values params::T end DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() +function get_range_and_linked( + ctx::FastLDFContext, ::VarName{sym,typeof(identity)} +) where {sym} + return ctx.iden_varname_ranges[sym] +end +function get_range_and_linked(ctx::FastLDFContext, vn::VarName) + return ctx.varname_ranges[vn] +end function tilde_assume!!( ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo ) # Don't need to read the data from the varinfo at all since it's # all inside the context. - range_and_linked = ctx.varname_ranges[vn] + range_and_linked = get_range_and_linked(ctx, vn) y = @view ctx.params[range_and_linked.range] - is_linked = range_and_linked.is_linked - f = if is_linked + f = if range_and_linked.is_linked from_linked_vec_transform(right) else from_vec_transform(right) @@ -51,11 +65,14 @@ end struct FastLDF{ M<:Model, F<:Function, + N<:NamedTuple, AD<:Union{ADTypes.AbstractADType,Nothing}, ADP<:Union{Nothing,DI.GradientPrep}, } _model::M _getlogdensity::F + # See FastLDFContext for explanation of these two fields + _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adtype::AD _adprep::ADP @@ -70,6 +87,7 @@ struct FastLDF{ ) where {syms} # Figure out which variable corresponds to which index, and # which variables are linked. + all_iden_ranges = NamedTuple() all_ranges = Dict{VarName,RangeAndLinked}() offset = 1 for sym in syms @@ -78,7 +96,16 @@ struct FastLDF{ len = length(md.ranges[idx]) is_linked = md.is_transformed[idx] range = offset:(offset + len - 1) - all_ranges[vn] = RangeAndLinked(range, is_linked) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple(( + AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), + )), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end offset += len end end @@ -90,23 +117,32 @@ struct FastLDF{ adtype = tweak_adtype(adtype, model, varinfo) x = [val for val in varinfo[:]] DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x + FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, ) end - return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}( - model, getlogdensity, all_ranges, adtype, prep + return new{ + typeof(model), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(adtype), + typeof(prep), + }( + model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep ) end end -struct FastLogDensityAt{M<:Model,F<:Function} +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M _getlogdensity::F + _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastLDFContext(f._varname_ranges, params) + ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) # This can obviously also be optimised for the case where not # all accumulators are needed. @@ -118,14 +154,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params) + return FastLogDensityAt( + fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + )( + params + ) end function LogDensityProblems.logdensity_and_gradient( fldf::FastLDF, params::AbstractVector{<:Real} ) return DI.value_and_gradient( - FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges), + FastLogDensityAt( + fldf._model, + fldf._getlogdensity, + fldf._iden_varname_ranges, + fldf._varname_ranges, + ), fldf._adprep, fldf._adtype, params, From 4cefaca4803107da0570430277a3f05a27ae2146 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 02:43:55 +0000 Subject: [PATCH 04/30] Mark `get_range_and_linked` as having zero derivative --- ext/DynamicPPLEnzymeCoreExt.jl | 13 ++++++------- ext/DynamicPPLMooncakeExt.jl | 3 ++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..29a4e2cc7 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL.get_range_and_linked), args... +) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..e49b81cb2 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,10 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, is_transformed +using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} end # module From 6dfd106ace57d776799c487c5dccbaed48211b13 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 03:11:32 +0000 Subject: [PATCH 05/30] Update comment --- src/fastldf.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index e59b12791..6c8798d4c 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -13,10 +13,8 @@ struct RangeAndLinked end struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # The ranges of identity VarNames are stored in a NamedTuple for performance - # reasons. For just plain evaluation this doesn't make _that_ much of a - # difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE - # difference (around 4x). Of course, the exact numbers depend on the model. + # The ranges of identity VarNames are stored in a NamedTuple for improved performance + # (it's around 1.5x faster). iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} From 41ee7f3d7b58503de91d835c42d394cd6e1818f1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:53:36 +0000 Subject: [PATCH 06/30] make AD testing / benchmarking use FastLDF --- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- src/fastldf.jl | 2 +- src/test_utils/ad.jl | 9 +-- test/ad.jl | 77 +++++++------------------- 4 files changed, 24 insertions(+), 68 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 225e40cd8..e6988d3f2 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -94,9 +94,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend - ) + f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/fastldf.jl b/src/fastldf.jl index 6c8798d4c..7ec193891 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -80,7 +80,7 @@ struct FastLDF{ getlogdensity::Function, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}; + varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) where {syms} # Figure out which variable corresponds to which index, and diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..fbbae85b7 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,8 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link +using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -265,7 +264,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) + ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 @@ -282,9 +281,7 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = LogDensityFunction( - model, getlogdensity, varinfo; adtype=test.adtype - ) + ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) diff --git a/test/ad.jl b/test/ad.jl index d7505aab2..6d140197e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -using DynamicPPL: LogDensityFunction +using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Automatic differentiation" begin @@ -15,64 +15,25 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. - if is_mooncake && is_1_11 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = DynamicPPL.getparams(f) + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any end end end @@ -83,7 +44,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) + ldf = FastLDF(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end From 22e32a6fbd0ed365bd327a3dd5fbd4d6d3016bfb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:00:40 +0000 Subject: [PATCH 07/30] Fix tests --- src/fastldf.jl | 2 +- test/ad.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 7ec193891..61194ab25 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -77,7 +77,7 @@ struct FastLDF{ function FastLDF( model::Model, - getlogdensity::Function, + getlogdensity::Function=getlogjoint_internal, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); diff --git a/test/ad.jl b/test/ad.jl index 6d140197e..48b1b64ec 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,5 +1,6 @@ using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using Random: Xoshiro @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -17,10 +18,10 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(m) + varinfo = VarInfo(Xoshiro(468), m) linked_varinfo = DynamicPPL.link(varinfo, m) f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) + x = linked_varinfo[:] # Calculate reference logp + gradient of logp using ForwardDiff ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) From 79cc1286b87d35fe082f2f02b803af4ebb2cc653 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:18:08 +0000 Subject: [PATCH 08/30] Optimise away `make_evaluate_args_and_kwargs` --- src/fastldf.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 61194ab25..ebaf002b4 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -147,10 +147,21 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) - _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + # _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + args = map(maybe_deepcopy, model.args) + _, vi = model.f(model, OnlyAccsVarInfo(accs), args...; model.defaults...) return f._getlogdensity(vi) end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + deepcopy(x) + else + x + end +end + function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From f7c6a78ba3446246abda6798a77513f6893aa49a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:23:58 +0000 Subject: [PATCH 09/30] const func annotation --- test/integration/enzyme/main.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES From b1a76509b6e88ac7f6099039bd042d3d502d4848 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:28:34 +0000 Subject: [PATCH 10/30] Disable benchmarks on non-typed-Metadata-VarInfo --- benchmarks/benchmarks.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..5fe0320cc 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -59,11 +59,11 @@ chosen_combinations = [ false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), From e60873a7e2f000733ef1390bcca7e164776e56db Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:36:34 +0000 Subject: [PATCH 11/30] Fix `_evaluate!!` correctly to handle submodels --- src/fastldf.jl | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index ebaf002b4..309155606 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -133,6 +133,23 @@ struct FastLDF{ end end +function _evaluate!!( + model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo +) where {F,A,D,M,TA,TD} + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) +end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + # avoid overwriting missing elements of model arguments when + # evaluating the model. + deepcopy(x) + else + x + end +end + struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M _getlogdensity::F @@ -147,21 +164,10 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) - # _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) - args = map(maybe_deepcopy, model.args) - _, vi = model.f(model, OnlyAccsVarInfo(accs), args...; model.defaults...) + _, vi = _evaluate!!(model, OnlyAccsVarInfo(accs)) return f._getlogdensity(vi) end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - deepcopy(x) - else - x - end -end - function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From fa0664ee9ec3c6d0eae891af1508423ac77f5643 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:41:51 +0000 Subject: [PATCH 12/30] Actually fix submodel evaluate --- src/fastldf.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 309155606..b1794ffa2 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -2,7 +2,6 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs -DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) struct RangeAndLinked @@ -133,11 +132,13 @@ struct FastLDF{ end end -function _evaluate!!( - model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo -) where {F,A,D,M,TA,TD} - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) +function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) + if leafcontext(model.context) isa FastLDFContext + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) + else + error("Shouldn't happen") + end end maybe_deepcopy(@nospecialize(x)) = x function maybe_deepcopy(x::AbstractArray{T}) where {T} From 09a1fbb4787cf70fd42c794ab4365ff99395964d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:04:26 +0000 Subject: [PATCH 13/30] Document thoroughly and organise code --- src/compiler.jl | 38 +++--- src/fastldf.jl | 341 +++++++++++++++++++++++++++++++++++++----------- src/model.jl | 20 ++- 3 files changed, 306 insertions(+), 93 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. """ get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/fastldf.jl b/src/fastldf.jl index b1794ffa2..215202230 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -1,9 +1,108 @@ +""" +fasteval.jl +----------- + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we implement here `OnlyAccsVarInfo`, which is a VarInfo that only +contains accumulators. When evaluating a model with `OnlyAccsVarInfo`, it is mandatory that +the model's leaf context is a `FastEvalContext`, which provides extremely fast access to +parameter values. No writing of values into VarInfo metadata is performed at all. + +Vector parameters +----------------- + +We first consider the case of parameter vectors, i.e., the case which would normally be +handled by `unflatten` and `evaluate!!`. Unfortunately, it is not enough to just store +the vector of parameters in the `FastEvalContext`, because it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +However, we want to avoid doing this. Thus, here, we _extract this information from the +VarInfo_ a single time when constructing a `FastLDF` object. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a +general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` +approach also fails with such models. + +NamedTuple and Dict parameters +------------------------------ + +Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Such +representations are capable of handling models with variable sizes and stochastic control +flow. + +However, the path towards implementing these is straightforward: + +1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter + value, plus a boolean indicating whether the value is linked or unlinked. See the + `get_range_and_linked` function for details. + +2. We would need to implement similar contexts for NamedTuple and Dict parameters. The + functionality would be quite similar to `InitContext(InitFromParams(...))`. +""" + +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this outside of FastLDF will lead to errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +That is because values for random variables are obtained by reading from a separate entity +(such as a `FastLDFContext`), rather than from the VarInfo itself. +""" struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) + # Because the VarInfo has no parameters stored in it, we need to get the eltype from the + # model's leaf context. This is only possible if said leaf context is indeed a FastEval + # context. + leaf_ctx = DynamicPPL.leafcontext(model) + if leaf_ctx isa FastEvalVectorContext + return eltype(leaf_ctx.params) + else + error( + "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", + ) + end +end +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. + +$(TYPEDFIELDS) +""" struct RangeAndLinked # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} @@ -11,30 +110,55 @@ struct RangeAndLinked is_linked::Bool end -struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # The ranges of identity VarNames are stored in a NamedTuple for improved performance - # (it's around 1.5x faster). +""" + AbstractFastEvalContext + +Abstract type representing fast evaluation contexts. This currently is only subtyped by +`FastEvalVectorContext`. However, in the future, similar contexts may be implemented for +NamedTuple and Dict parameters. +""" +abstract type AbstractFastEvalContext <: AbstractContext end +DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf() + +""" + FastEvalVectorContext( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + params::AbstractVector{<:Real}, + ) + +A context that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to unify the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext + # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values params::T end -DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() function get_range_and_linked( - ctx::FastLDFContext, ::VarName{sym,typeof(identity)} + ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} ) where {sym} return ctx.iden_varname_ranges[sym] end -function get_range_and_linked(ctx::FastLDFContext, vn::VarName) +function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) return ctx.varname_ranges[vn] end function tilde_assume!!( - ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo + ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) - # Don't need to read the data from the varinfo at all since it's - # all inside the context. + # Note that this function does not use the metadata field of `vi` at all. range_and_linked = get_range_and_linked(ctx, vn) y = @view ctx.params[range_and_linked.range] f = if range_and_linked.is_linked @@ -48,64 +172,111 @@ function tilde_assume!!( end function tilde_observe!!( - ::FastLDFContext, + ::FastEvalVectorContext, right::Distribution, left, vn::Union{VarName,Nothing}, - vi::OnlyAccsVarInfo, + vi::AbstractVarInfo, ) # This is the same as for DefaultContext vi = accumulate_observe!!(vi, right, left, vn) return left, vi end +######################################## +# Log-density functions using FastEval # +######################################## + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created + with a linked or unlinked VarInfo. This is done primarily to ease interoperability with + MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend +itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: + +- `fastldf.model`: The original model from which this `FastLDF` was constructed. +- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +## Extended help + +`FastLDF` uses `FastEvalVectorContext` internally to provide extremely rapid evaluation of +the model given a vector of parameters. + +Because it is common to call `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient` within tight loops, it is beneficial for us to +pre-compute as much of the information as possible when constructing the `FastLDF` object. +In particular, we use the provided VarInfo's metadata to extract the mapping from VarNames +to ranges and link status, and store this mapping inside the `FastLDF` object. We can later +use this to construct a FastEvalVectorContext, without having to look into a metadata again. +""" struct FastLDF{ M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, N<:NamedTuple, - AD<:Union{ADTypes.AbstractADType,Nothing}, ADP<:Union{Nothing,DI.GradientPrep}, } - _model::M + model::M + adtype::AD _getlogdensity::F - # See FastLDFContext for explanation of these two fields + # See FastLDFContext for explanation of these two fields. _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} - _adtype::AD _adprep::ADP function FastLDF( model::Model, getlogdensity::Function=getlogjoint_internal, - # This only works with typed Metadata-varinfo. - # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) where {syms} + ) # Figure out which variable corresponds to which index, and # which variables are linked. - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - for (vn, idx) in md.idcs - len = length(md.ranges[idx]) - is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple(( - AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), - )), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += len - end - end + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) # Do AD prep if needed prep = if adtype === nothing nothing @@ -119,37 +290,37 @@ struct FastLDF{ x, ) end - return new{ typeof(model), + typeof(adtype), typeof(getlogdensity), typeof(all_iden_ranges), - typeof(adtype), typeof(prep), }( - model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep ) end end -function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) - if leafcontext(model.context) isa FastLDFContext - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) - else - error("Shouldn't happen") - end +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - # avoid overwriting missing elements of model arguments when - # evaluating the model. - deepcopy(x) - else - x - end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M @@ -158,20 +329,15 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) + ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - # This can obviously also be optimised for the case where not - # all accumulators are needed. - accs = AccumulatorTuple(( - LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() - )) - _, vi = _evaluate!!(model, OnlyAccsVarInfo(accs)) + _, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))) return f._getlogdensity(vi) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( - fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges )( params ) @@ -182,13 +348,42 @@ function LogDensityProblems.logdensity_and_gradient( ) return DI.value_and_gradient( FastLogDensityAt( - fldf._model, - fldf._getlogdensity, - fldf._iden_varname_ranges, - fldf._varname_ranges, + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges ), fldf._adprep, - fldf._adtype, + fldf.adtype, params, ) end + +###################################################### +# Helper functions to extract ranges and link status # +###################################################### + +# TODO: Fails for other VarInfo types. +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + # TODO: Fails for VarNamedVector. + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple(( + AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), + )), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + end + return all_iden_ranges, all_ranges +end diff --git a/src/model.jl b/src/model.jl index edb042ba9..6ca06aea6 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)...) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1006,22 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, model::Model) + +Get the element type of the parameters being used to evaluate the `model` from the +`varinfo`. For example, when performing AD with ForwardDiff, this should return +`ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +See `OnlyAccsVarInfo` for an example of where this is not true (the parameters are instead +stored in the model's context). +""" +get_param_eltype(varinfo::AbstractVarInfo, ::Model) = eltype(varinfo) + """ getargnames(model::Model) From 7306ba46158cffa4a1dca405f5fe8fd68930046c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:27:57 +0000 Subject: [PATCH 14/30] Support more VarInfos, make it thread-safe (?) --- src/fastldf.jl | 99 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 21 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 215202230..87dd698dc 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -77,13 +77,14 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) # Because the VarInfo has no parameters stored in it, we need to get the eltype from the # model's leaf context. This is only possible if said leaf context is indeed a FastEval # context. - leaf_ctx = DynamicPPL.leafcontext(model) + leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) else @@ -138,7 +139,8 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to unify the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: + AbstractFastEvalContext # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames @@ -331,7 +333,17 @@ end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - _, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))) + only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + vi = if Threads.nthreads() > 1 + ThreadSafeVarInfo(only_accs_vi) + else + only_accs_vi + end + _, vi = _evaluate!!(model, vi) return f._getlogdensity(vi) end @@ -360,30 +372,75 @@ end # Helper functions to extract ranges and link status # ###################################################### -# TODO: Fails for other VarInfo types. +# TODO: Fails for SimpleVarInfo. Do I really care enough? Ehhh, honestly, debatable. + +""" + get_ranges_and_linked(varinfo::VarInfo) + +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. + +This function should return a tuple containing: + +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +""" function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} all_iden_ranges = NamedTuple() all_ranges = Dict{VarName,RangeAndLinked}() offset = 1 for sym in syms md = varinfo.metadata[sym] - # TODO: Fails for VarNamedVector. - for (vn, idx) in md.idcs - len = length(md.ranges[idx]) - is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple(( - AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), - )), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += len - end + this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata( + md, offset + ) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + offset = new_offset end return all_iden_ranges, all_ranges end +function get_ranges_and_linked(varinfo::VarInfo{<:Metadata}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + len = length(vnv.ranges[idx]) + is_linked = vnv.is_unconstrained[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + return all_iden_ranges, all_ranges, offset +end From 53bccc13dbd7976a2e12a4c3057ba71c2d56e1ad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:46:06 +0000 Subject: [PATCH 15/30] fix bug in parsing ranges from metadata/VNV --- src/fastldf.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 87dd698dc..15326a83b 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -391,16 +391,13 @@ function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms offset = 1 for sym in syms md = varinfo.metadata[sym] - this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata( - md, offset - ) + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) all_iden_ranges = merge(all_iden_ranges, this_md_iden) all_ranges = merge(all_ranges, this_md_others) - offset = new_offset end return all_iden_ranges, all_ranges end -function get_ranges_and_linked(varinfo::VarInfo{<:Metadata}) +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) return all_iden, all_others end @@ -409,9 +406,8 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) all_ranges = Dict{VarName,RangeAndLinked}() offset = start_offset for (vn, idx) in md.idcs - len = length(md.ranges[idx]) is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) + range = md.ranges[idx] .+ (start_offset - 1) if AbstractPPL.getoptic(vn) === identity all_iden_ranges = merge( all_iden_ranges, @@ -420,7 +416,7 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) else all_ranges[vn] = RangeAndLinked(range, is_linked) end - offset += len + offset += length(range) end return all_iden_ranges, all_ranges, offset end @@ -429,9 +425,8 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) all_ranges = Dict{VarName,RangeAndLinked}() offset = start_offset for (vn, idx) in vnv.varname_to_index - len = length(vnv.ranges[idx]) is_linked = vnv.is_unconstrained[idx] - range = offset:(offset + len - 1) + range = vnv.ranges[idx] .+ (start_offset - 1) if AbstractPPL.getoptic(vn) === identity all_iden_ranges = merge( all_iden_ranges, @@ -440,7 +435,7 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) else all_ranges[vn] = RangeAndLinked(range, is_linked) end - offset += len + offset += length(range) end return all_iden_ranges, all_ranges, offset end From 30b9247080f7e15fa8e5259b2fbc8237c58b5487 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:04:40 +0000 Subject: [PATCH 16/30] Fix get_param_eltype for TSVI --- src/fastldf.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 15326a83b..eaca0c795 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -80,7 +80,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) -function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) +function DynamicPPL.get_param_eltype( + ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model +) # Because the VarInfo has no parameters stored in it, we need to get the eltype from the # model's leaf context. This is only possible if said leaf context is indeed a FastEval # context. @@ -333,15 +335,16 @@ end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)) + accs = fast_ldf_accs(f._getlogdensity) # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. vi = if Threads.nthreads() > 1 - ThreadSafeVarInfo(only_accs_vi) + accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs) + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) else - only_accs_vi + OnlyAccsVarInfo(accs) end _, vi = _evaluate!!(model, vi) return f._getlogdensity(vi) From 316937a2ff3014b4a2383d2dadc79134db60a258 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:17:26 +0000 Subject: [PATCH 17/30] Disable Enzyme benchmark --- benchmarks/benchmarks.jl | 2 +- src/fastldf.jl | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 5fe0320cc..e78bf602f 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -66,7 +66,7 @@ chosen_combinations = [ # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), + # ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), diff --git a/src/fastldf.jl b/src/fastldf.jl index eaca0c795..c06f3495c 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -350,6 +350,25 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) return f._getlogdensity(vi) end +function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) + if leafcontext(model.context) isa FastEvalVectorContext + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) + else + error("Shouldn't happen") + end +end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + # avoid overwriting missing elements of model arguments when + # evaluating the model. + deepcopy(x) + else + x + end +end + function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 075cee8a3da81e1a96558274448fe8a7458f4f6e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:20:31 +0000 Subject: [PATCH 18/30] Don't override _evaluate!!, that breaks ForwardDiff (sometimes) --- src/fastldf.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index c06f3495c..eaca0c795 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -350,25 +350,6 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) return f._getlogdensity(vi) end -function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) - if leafcontext(model.context) isa FastEvalVectorContext - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) - else - error("Shouldn't happen") - end -end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - # avoid overwriting missing elements of model arguments when - # evaluating the model. - deepcopy(x) - else - x - end -end - function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 5f5a92c405fad90ab008ad2963472431a781a6a8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 20:49:18 +0000 Subject: [PATCH 19/30] Move FastLDF to experimental for now --- benchmarks/benchmarks.jl | 12 ++-- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- src/DynamicPPL.jl | 1 - src/experimental.jl | 2 + src/{fastldf.jl => fasteval.jl} | 0 src/test_utils/ad.jl | 9 ++- test/ad.jl | 78 +++++++++++++++++++------- 7 files changed, 75 insertions(+), 31 deletions(-) rename src/{fastldf.jl => fasteval.jl} (100%) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e78bf602f..035d8ff49 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -59,14 +59,14 @@ chosen_combinations = [ false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - # ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - # ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), + ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index e6988d3f2..225e40cd8 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -94,7 +94,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 77d527ced..e66f3fe11 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -191,7 +191,6 @@ include("simple_varinfo.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") -include("fastldf.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/experimental.jl b/src/experimental.jl index 8c82dca68..31cece3b4 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,6 +2,8 @@ module Experimental using DynamicPPL: DynamicPPL +include("fastldf.jl") + # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) diff --git a/src/fastldf.jl b/src/fasteval.jl similarity index 100% rename from src/fastldf.jl rename to src/fasteval.jl diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index fbbae85b7..a49ffd18b 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link +using DynamicPPL: + Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -264,7 +265,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype) + ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 @@ -281,7 +282,9 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype) + ldf_reference = LogDensityFunction( + model, getlogdensity, varinfo; adtype=test.adtype + ) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) diff --git a/test/ad.jl b/test/ad.jl index 48b1b64ec..d7505aab2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,5 @@ -using DynamicPPL: FastLDF +using DynamicPPL: LogDensityFunction using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest -using Random: Xoshiro @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -16,25 +15,64 @@ using Random: Xoshiro [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") LogDensityFunction( + demo(); adtype=AutoZygote() + ) + end + @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(Xoshiro(468), m) - linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = linked_varinfo[:] - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $adtype" - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = DynamicPPL.getparams(f) + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" + + # Put predicates here to avoid long lines + is_mooncake = adtype isa AutoMooncake + is_1_10 = v"1.10" <= VERSION < v"1.11" + is_1_11 = v"1.11" <= VERSION < v"1.12" + is_svi_vnv = + linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} + + # Mooncake doesn't work with several combinations of SimpleVarInfo. + if is_mooncake && is_1_11 && is_svi_vnv + # https://github.com/compintell/Mooncake.jl/issues/470 + @test_throws ArgumentError DynamicPPL.LogDensityFunction( + m, getlogjoint_internal, linked_varinfo; adtype=adtype + ) + elseif is_mooncake && is_1_10 && is_svi_vnv + # TODO: report upstream + @test_throws UndefRefError DynamicPPL.LogDensityFunction( + m, getlogjoint_internal, linked_varinfo; adtype=adtype + ) + elseif is_mooncake && is_1_10 && is_svi_od + # TODO: report upstream + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( + m, getlogjoint_internal, linked_varinfo; adtype=adtype + ) + else + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end end end end @@ -45,7 +83,7 @@ using Random: Xoshiro test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) + ldf = LogDensityFunction(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end From 0716de57d93976cfbfbee68840d251bf29c628de Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 20:56:09 +0000 Subject: [PATCH 20/30] Fix imports, add tests, etc --- src/experimental.jl | 2 +- src/fasteval.jl | 40 +++++++++++++-- test/fasteval.jl | 115 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 153 insertions(+), 5 deletions(-) create mode 100644 test/fasteval.jl diff --git a/src/experimental.jl b/src/experimental.jl index 31cece3b4..c644c09b2 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,7 +2,7 @@ module Experimental using DynamicPPL: DynamicPPL -include("fastldf.jl") +include("fasteval.jl") # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ diff --git a/src/fasteval.jl b/src/fasteval.jl index eaca0c795..2bd7b3a0c 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,35 @@ However, the path towards implementing these is straightforward: functionality would be quite similar to `InitContext(InitFromParams(...))`. """ +using DynamicPPL: + AbstractContext, + AbstractVarInfo, + AccumulatorTuple, + Metadata, + Model, + ThreadSafeVarInfo, + VarInfo, + VarNamedVector, + accumulate_assume!!, + accumulate_observe!!, + default_accumulators, + float_type_with_fallback, + from_linked_vec_transform, + from_vec_transform, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal, + leafcontext +using ADTypes: ADTypes +using Bijectors: with_logabsdet_jacobian +using AbstractPPL: AbstractPPL, VarName +using Distributions: Distribution +using DocStringExtensions: TYPEDFIELDS +using LogDensityProblems: LogDensityProblems +import DifferentiationInterface as DI + """ OnlyAccsVarInfo @@ -121,7 +150,7 @@ Abstract type representing fast evaluation contexts. This currently is only subt NamedTuple and Dict parameters. """ abstract type AbstractFastEvalContext <: AbstractContext end -DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf() +DynamicPPL.NodeTrait(::AbstractFastEvalContext) = DynamicPPL.IsLeaf() """ FastEvalVectorContext( @@ -286,7 +315,7 @@ struct FastLDF{ nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) x = [val for val in varinfo[:]] DI.prepare_gradient( FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), @@ -341,12 +370,15 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. vi = if Threads.nthreads() > 1 - accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs) + accs = map( + acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), + accs, + ) ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) else OnlyAccsVarInfo(accs) end - _, vi = _evaluate!!(model, vi) + _, vi = DynamicPPL._evaluate!!(model, vi) return f._getlogdensity(vi) end diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..563d2adba --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,115 @@ +module DynamicPPLFastLDFTests + +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.Experimental: FastLDF + +@testset "Automatic differentiation" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + + test_adtypes = if MOONCAKE_SUPPORTED + [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + else + [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] + end + + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") FastLDF(demo(); adtype=AutoZygote()) + end + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = linked_varinfo[:] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. + @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = FastLDF(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 861d3bb87..10fac8b0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,7 @@ include("test_util.jl") include("ext/DynamicPPLMooncakeExt.jl") end include("ad.jl") + include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." From cd2461efc96d883e1d824f04324fdfd2ec1658a1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:00:14 +0000 Subject: [PATCH 21/30] More test fixes --- ext/DynamicPPLEnzymeCoreExt.jl | 2 +- ext/DynamicPPLMooncakeExt.jl | 6 ++++-- test/fasteval.jl | 18 ++++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 29a4e2cc7..eacc35046 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -9,7 +9,7 @@ using EnzymeCore nothing # Likewise for get_range_and_linked. @inline EnzymeCore.EnzymeRules.inactive( - ::typeof(DynamicPPL.get_range_and_linked), args... + ::typeof(DynamicPPL.Experimental.get_range_and_linked), args... ) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index e49b81cb2..63c754f4c 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,10 +1,12 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked +using DynamicPPL: DynamicPPL, is_transformed using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} -Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL.Experimental.get_range_and_linked),Vararg +} end # module diff --git a/test/fasteval.jl b/test/fasteval.jl index 563d2adba..3d11c6889 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -5,12 +5,23 @@ using Distributions using DistributionsAD: filldist using ADTypes using DynamicPPL.Experimental: FastLDF +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using Test + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +# Need to include this block here in case we run this test file standalone +@static if VERSION < v"1.12" + using Pkg + Pkg.add("Mooncake") + using Mooncake: Mooncake +end @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() - test_adtypes = if MOONCAKE_SUPPORTED + test_adtypes = @static if VERSION < v"1.12" [ AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), @@ -20,11 +31,6 @@ using DynamicPPL.Experimental: FastLDF [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") FastLDF(demo(); adtype=AutoZygote()) - end - @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS varinfo = VarInfo(m) From 1b8b87369a29bdd25a4b2c86879ec0064081d85e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:37:26 +0000 Subject: [PATCH 22/30] Fix imports / tests --- src/fasteval.jl | 26 ++++++++++++++++---------- test/fasteval.jl | 43 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 2bd7b3a0c..533e19812 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -50,14 +50,10 @@ Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Suc representations are capable of handling models with variable sizes and stochastic control flow. -However, the path towards implementing these is straightforward: - -1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter - value, plus a boolean indicating whether the value is linked or unlinked. See the - `get_range_and_linked` function for details. - -2. We would need to implement similar contexts for NamedTuple and Dict parameters. The - functionality would be quite similar to `InitContext(InitFromParams(...))`. +However, the path towards implementing these is straightforward: just make `InitContext` work +correctly with `OnlyAccsVarInfo`. There will probably be a few functions that need to be +overloaded to make this work: for example `push!!` on `OnlyAccsVarInfo` can just be defined +as a no-op. """ using DynamicPPL: @@ -119,6 +115,13 @@ function DynamicPPL.get_param_eltype( if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) else + # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. + # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to + # figure out the parameter type from a NamedTuple or Dict. The benefit of + # implementing this for InitContext is that we could then use OnlyAccsVarInfo with + # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe + # that Mooncake / Enzyme should be able to differentiate through that too and + # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) @@ -188,7 +191,7 @@ function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) return ctx.varname_ranges[vn] end -function tilde_assume!!( +function DynamicPPL.tilde_assume!!( ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) # Note that this function does not use the metadata field of `vi` at all. @@ -204,7 +207,7 @@ function tilde_assume!!( return x, vi end -function tilde_observe!!( +function DynamicPPL.tilde_observe!!( ::FastEvalVectorContext, right::Distribution, left, @@ -369,6 +372,9 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 vi = if Threads.nthreads() > 1 accs = map( acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), diff --git a/test/fasteval.jl b/test/fasteval.jl index 3d11c6889..277384350 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -1,12 +1,15 @@ module DynamicPPLFastLDFTests +using AbstractPPL: AbstractPPL using DynamicPPL using Distributions using DistributionsAD: filldist using ADTypes using DynamicPPL.Experimental: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I using Test +using LogDensityProblems: LogDensityProblems using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff @@ -17,7 +20,43 @@ using ReverseDiff: ReverseDiff using Mooncake: Mooncake end -@testset "Automatic differentiation" begin +@testset "get_ranges_and_linked" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = vi[:] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end +end + +@testset "AD with FastLDF" begin # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() @@ -43,7 +82,7 @@ end ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" + @info "Testing AD on: $(m.f) - $adtype" @test run_ad( m, From ff5680d6cc63d007dd4945892dec1f4932b2aad2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:56:32 +0000 Subject: [PATCH 23/30] Remove AbstractFastEvalContext --- src/fasteval.jl | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 533e19812..a0df8b373 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -145,22 +145,12 @@ struct RangeAndLinked is_linked::Bool end -""" - AbstractFastEvalContext - -Abstract type representing fast evaluation contexts. This currently is only subtyped by -`FastEvalVectorContext`. However, in the future, similar contexts may be implemented for -NamedTuple and Dict parameters. -""" -abstract type AbstractFastEvalContext <: AbstractContext end -DynamicPPL.NodeTrait(::AbstractFastEvalContext) = DynamicPPL.IsLeaf() - """ FastEvalVectorContext( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, params::AbstractVector{<:Real}, - ) + ) <: AbstractContext A context that wraps a vector of parameter values, plus information about how random variables map to ranges in that vector. @@ -173,8 +163,7 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to unify the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: - AbstractFastEvalContext +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames @@ -182,6 +171,8 @@ struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: # The full parameter vector which we index into to get variable values params::T end +DynamicPPL.NodeTrait(::FastEvalVectorContext) = DynamicPPL.IsLeaf() + function get_range_and_linked( ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} ) where {sym} From 500d5acbefa26222d1fcb6269c2493a198207542 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:59:47 +0000 Subject: [PATCH 24/30] Changelog and patch bump --- HISTORY.md | 10 ++++++++++ Project.toml | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index d9be6da03..22a25bd8e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,15 @@ # DynamicPPL Changelog +## 0.38.9 + +Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +Please note that `FastLDF` is currently considered internal and its API may change without warning. +We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. + +For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + ## 0.38.8 Added a new exported struct, `DynamicPPL.ParamsWithStats`. diff --git a/Project.toml b/Project.toml index c71b89bc7..23f5eec5b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.8" +version = "0.38.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From e560c301f692f80f613c0a7c9f1f75c0ee726a11 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 22:07:56 +0000 Subject: [PATCH 25/30] Add correctness tests, fix imports --- src/fasteval.jl | 3 +++ test/fasteval.jl | 22 +++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index a0df8b373..c91254d43 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,9 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, Metadata, Model, ThreadSafeVarInfo, diff --git a/test/fasteval.jl b/test/fasteval.jl index 277384350..051c82f0f 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -20,7 +20,7 @@ using ReverseDiff: ReverseDiff using Mooncake: Mooncake end -@testset "get_ranges_and_linked" begin +@testset "FastLDF: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS @testset "$varinfo_func" for varinfo_func in [ DynamicPPL.untyped_varinfo, @@ -51,6 +51,26 @@ end # Check that the link status is correct @test range_with_linked.is_linked == islinked end + + # Compare results of FastLDF vs ordinary LogDensityFunction. These tests + # can eventually go once we replace LogDensityFunction with FastLDF, but + # for now it helps to have this check! (Eventually we should just check + # against manually computed log-densities). + # + # TODO(penelopeysm): I think we need to add tests for some really + # pathological models here. + @testset "$getlogdensity" for getlogdensity in ( + DynamicPPL.getlogjoint_internal, + DynamicPPL.getlogjoint, + DynamicPPL.getloglikelihood, + DynamicPPL.getlogprior_internal, + DynamicPPL.getlogprior, + ) + ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) + fldf = FastLDF(m, getlogdensity, vi) + @test LogDensityProblems.logdensity(ldf, params) ≈ + LogDensityProblems.logdensity(fldf, params) + end end end end From 86d8a735d840e25c73860d5eee7d871e178e111c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 17:25:43 +0000 Subject: [PATCH 26/30] Concretise parameter vector in tests --- test/fasteval.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index 051c82f0f..863ec369e 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -36,7 +36,7 @@ end unlinked_vi end nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) - params = vi[:] + params = map(identity, vi[:]) # Iterate over all variables for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using the ranges @@ -95,7 +95,7 @@ end varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = linked_varinfo[:] + x = map(identity, linked_varinfo[:]) # Calculate reference logp + gradient of logp using ForwardDiff ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) From 4b324b03cb21de67d5a87e813dfde4ad8af528aa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Nov 2025 02:42:07 +0000 Subject: [PATCH 27/30] Add zero-allocation tests --- test/fasteval.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/fasteval.jl b/test/fasteval.jl index 863ec369e..9907e18a1 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -1,6 +1,7 @@ module DynamicPPLFastLDFTests using AbstractPPL: AbstractPPL +using Chairmarks using DynamicPPL using Distributions using DistributionsAD: filldist @@ -76,6 +77,34 @@ end end end +@testset "FastLDF: performance" begin + # Evaluating these three models should not lead to any allocations. + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end +end + @testset "AD with FastLDF" begin # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() From c55171b80b01d52d08848f17ad46bc1d2398788c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Nov 2025 03:01:34 +0000 Subject: [PATCH 28/30] Add Chairmarks as test dep --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..efd916308 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" From 77deae96084361f641ecc4873e1186b341031882 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Nov 2025 03:14:18 +0000 Subject: [PATCH 29/30] Disable allocations tests on multi-threaded --- test/fasteval.jl | 53 ++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index 9907e18a1..57a2c937d 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -78,30 +78,35 @@ end end @testset "FastLDF: performance" begin - # Evaluating these three models should not lead to any allocations. - @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF(model, DynamicPPL.getlogjoint_internal, vi) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end end end From 8715446bd4ac39ad20013470554156ec8dbba9a0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 10 Nov 2025 19:28:52 +0000 Subject: [PATCH 30/30] Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation --- src/contexts/init.jl | 2 +- src/fasteval.jl | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..efc6f1087 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..fbc6a61ce 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,10 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -81,6 +85,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} + return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) + elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} + # No need to enforce any particular eltype here, since new parameters are sampled + return Any else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", )