From 5bd9748313d495720abb1f58d6441dc4ca05dcde Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 5 Nov 2025 23:58:25 +0000 Subject: [PATCH 01/57] Fast Log Density Function --- src/DynamicPPL.jl | 1 + src/fastldf.jl | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 src/fastldf.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c43bd89d5..1a5c338ff 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -196,6 +196,7 @@ include("simple_varinfo.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") +include("fastldf.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/fastldf.jl b/src/fastldf.jl new file mode 100644 index 000000000..05b024dbe --- /dev/null +++ b/src/fastldf.jl @@ -0,0 +1,91 @@ +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext + varname_ranges::Dict{VarName,RangeAndLinked} + params::T +end +DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() + +function tilde_assume!!( + ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo +) + # Don't need to read the data from the varinfo at all since it's + # all inside the context. 
+ range_and_linked = ctx.varname_ranges[vn] + y = @view ctx.params[range_and_linked.range] + is_linked = range_and_linked.is_linked + f = if is_linked + from_linked_vec_transform(right) + else + from_vec_transform(right) + end + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +function tilde_observe!!( + ::FastLDFContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::OnlyAccsVarInfo, +) + # This is the same as for DefaultContext + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi +end + +struct FastLDF{M<:Model,F<:Function} + _model::M + _getlogdensity::F + _varname_ranges::Dict{VarName,RangeAndLinked} + + function FastLDF( + model::Model, + getlogdensity::Function, + # This only works with typed Metadata-varinfo. + # Obviously, this can be generalised later. + varinfo::VarInfo{<:NamedTuple{syms}}, + ) where {syms} + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + all_ranges[vn] = RangeAndLinked(range, is_linked) + offset += len + end + end + return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) + end +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + ctx = FastLDFContext(fldf._varname_ranges, params) + model = DynamicPPL.setleafcontext(fldf._model, ctx) + # This can obviously also be optimised for the case where not + # all accumulators are needed. 
+ accs = AccumulatorTuple(( + LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() + )) + _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + return fldf._getlogdensity(vi) +end From 38d63f42c3dd9ede6461a3a8887d5fbd3357ddac Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 00:33:49 +0000 Subject: [PATCH 02/57] Make it work with AD --- src/fastldf.jl | 56 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 05b024dbe..2178c84a2 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -48,17 +48,25 @@ function tilde_observe!!( return left, vi end -struct FastLDF{M<:Model,F<:Function} +struct FastLDF{ + M<:Model, + F<:Function, + AD<:Union{ADTypes.AbstractADType,Nothing}, + ADP<:Union{Nothing,DI.GradientPrep}, +} _model::M _getlogdensity::F _varname_ranges::Dict{VarName,RangeAndLinked} + _adtype::AD + _adprep::ADP function FastLDF( model::Model, getlogdensity::Function, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}, + varinfo::VarInfo{<:NamedTuple{syms}}; + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) where {syms} # Figure out which variable corresponds to which index, and # which variables are linked. 
@@ -74,18 +82,52 @@ struct FastLDF{M<:Model,F<:Function} offset += len end end - return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = tweak_adtype(adtype, model, varinfo) + x = [val for val in varinfo[:]] + DI.prepare_gradient( + FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x + ) + end + + return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}( + model, getlogdensity, all_ranges, adtype, prep + ) end end -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - ctx = FastLDFContext(fldf._varname_ranges, params) - model = DynamicPPL.setleafcontext(fldf._model, ctx) +struct FastLogDensityAt{M<:Model,F<:Function} + _model::M + _getlogdensity::F + _varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = FastLDFContext(f._varname_ranges, params) + model = DynamicPPL.setleafcontext(f._model, ctx) # This can obviously also be optimised for the case where not # all accumulators are needed. 
accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) - return fldf._getlogdensity(vi) + return f._getlogdensity(vi) +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params) +end + +function LogDensityProblems.logdensity_and_gradient( + fldf::FastLDF, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges), + fldf._adprep, + fldf._adtype, + params, + ) end From 0b475cabf3e47998b1aeb3ee0b36c8eb32796b09 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 01:37:05 +0000 Subject: [PATCH 03/57] Optimise performance for identity VarNames --- src/fastldf.jl | 69 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 2178c84a2..e59b12791 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -12,21 +12,35 @@ struct RangeAndLinked is_linked::Bool end -struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext +struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext + # The ranges of identity VarNames are stored in a NamedTuple for performance + # reasons. For just plain evaluation this doesn't make _that_ much of a + # difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE + # difference (around 4x). Of course, the exact numbers depend on the model. 
+ iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values params::T end DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() +function get_range_and_linked( + ctx::FastLDFContext, ::VarName{sym,typeof(identity)} +) where {sym} + return ctx.iden_varname_ranges[sym] +end +function get_range_and_linked(ctx::FastLDFContext, vn::VarName) + return ctx.varname_ranges[vn] +end function tilde_assume!!( ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo ) # Don't need to read the data from the varinfo at all since it's # all inside the context. - range_and_linked = ctx.varname_ranges[vn] + range_and_linked = get_range_and_linked(ctx, vn) y = @view ctx.params[range_and_linked.range] - is_linked = range_and_linked.is_linked - f = if is_linked + f = if range_and_linked.is_linked from_linked_vec_transform(right) else from_vec_transform(right) @@ -51,11 +65,14 @@ end struct FastLDF{ M<:Model, F<:Function, + N<:NamedTuple, AD<:Union{ADTypes.AbstractADType,Nothing}, ADP<:Union{Nothing,DI.GradientPrep}, } _model::M _getlogdensity::F + # See FastLDFContext for explanation of these two fields + _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adtype::AD _adprep::ADP @@ -70,6 +87,7 @@ struct FastLDF{ ) where {syms} # Figure out which variable corresponds to which index, and # which variables are linked. 
+ all_iden_ranges = NamedTuple() all_ranges = Dict{VarName,RangeAndLinked}() offset = 1 for sym in syms @@ -78,7 +96,16 @@ struct FastLDF{ len = length(md.ranges[idx]) is_linked = md.is_transformed[idx] range = offset:(offset + len - 1) - all_ranges[vn] = RangeAndLinked(range, is_linked) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple(( + AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), + )), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end offset += len end end @@ -90,23 +117,32 @@ struct FastLDF{ adtype = tweak_adtype(adtype, model, varinfo) x = [val for val in varinfo[:]] DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x + FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, ) end - return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}( - model, getlogdensity, all_ranges, adtype, prep + return new{ + typeof(model), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(adtype), + typeof(prep), + }( + model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep ) end end -struct FastLogDensityAt{M<:Model,F<:Function} +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M _getlogdensity::F + _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastLDFContext(f._varname_ranges, params) + ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) # This can obviously also be optimised for the case where not # all accumulators are needed. 
@@ -118,14 +154,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params) + return FastLogDensityAt( + fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + )( + params + ) end function LogDensityProblems.logdensity_and_gradient( fldf::FastLDF, params::AbstractVector{<:Real} ) return DI.value_and_gradient( - FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges), + FastLogDensityAt( + fldf._model, + fldf._getlogdensity, + fldf._iden_varname_ranges, + fldf._varname_ranges, + ), fldf._adprep, fldf._adtype, params, From 9123b2b20685ec849938e31269e181578b1cc71e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 02:43:55 +0000 Subject: [PATCH 04/57] Mark `get_range_and_linked` as having zero derivative --- ext/DynamicPPLEnzymeCoreExt.jl | 13 ++++++------- ext/DynamicPPLMooncakeExt.jl | 3 ++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..29a4e2cc7 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL.get_range_and_linked), args... 
+) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..e49b81cb2 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,10 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, is_transformed +using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} end # module From 326f4b179ffd33abe309987e1ffdb1d16d703cdf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 03:11:32 +0000 Subject: [PATCH 05/57] Update comment --- src/fastldf.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index e59b12791..6c8798d4c 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -13,10 +13,8 @@ struct RangeAndLinked end struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # The ranges of identity VarNames are stored in a NamedTuple for performance - # reasons. For just plain evaluation this doesn't make _that_ much of a - # difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE - # difference (around 4x). Of course, the exact numbers depend on the model. + # The ranges of identity VarNames are stored in a NamedTuple for improved performance + # (it's around 1.5x faster). 
iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} From 18c2d2e4879cad906904b99ae929888e3f240c72 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:53:36 +0000 Subject: [PATCH 06/57] make AD testing / benchmarking use FastLDF --- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- src/fastldf.jl | 2 +- src/test_utils/ad.jl | 9 +-- test/ad.jl | 77 +++++++------------------- 4 files changed, 24 insertions(+), 68 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 225e40cd8..e6988d3f2 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -94,9 +94,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend - ) + f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/fastldf.jl b/src/fastldf.jl index 6c8798d4c..7ec193891 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -80,7 +80,7 @@ struct FastLDF{ getlogdensity::Function, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. 
- varinfo::VarInfo{<:NamedTuple{syms}}; + varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) where {syms} # Figure out which variable corresponds to which index, and diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..fbbae85b7 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,8 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link +using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -265,7 +264,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) + ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 @@ -282,9 +281,7 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = LogDensityFunction( - model, getlogdensity, varinfo; adtype=test.adtype - ) + ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) diff --git a/test/ad.jl b/test/ad.jl index d7505aab2..6d140197e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -using DynamicPPL: LogDensityFunction +using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, 
WithExpectedResult, NoTest @testset "Automatic differentiation" begin @@ -15,64 +15,25 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = DynamicPPL.getparams(f) + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any end end end @@ -83,7 +44,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) + ldf = FastLDF(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end From d5f58f8c6d2599c0ea5fd437fa92ff9b9d5fca83 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:00:40 +0000 Subject: [PATCH 07/57] Fix tests --- src/fastldf.jl | 2 +- test/ad.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 7ec193891..61194ab25 
100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -77,7 +77,7 @@ struct FastLDF{ function FastLDF( model::Model, - getlogdensity::Function, + getlogdensity::Function=getlogjoint_internal, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); diff --git a/test/ad.jl b/test/ad.jl index 6d140197e..48b1b64ec 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,5 +1,6 @@ using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using Random: Xoshiro @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -17,10 +18,10 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(m) + varinfo = VarInfo(Xoshiro(468), m) linked_varinfo = DynamicPPL.link(varinfo, m) f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) + x = linked_varinfo[:] # Calculate reference logp + gradient of logp using ForwardDiff ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) From 5620efe57dd76eccd8da11c775b51c459b6f70a2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:18:08 +0000 Subject: [PATCH 08/57] Optimise away `make_evaluate_args_and_kwargs` --- src/fastldf.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 61194ab25..ebaf002b4 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -147,10 +147,21 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) - _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + # _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + args = map(maybe_deepcopy, model.args) + _, vi = model.f(model, 
OnlyAccsVarInfo(accs), args...; model.defaults...) return f._getlogdensity(vi) end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + deepcopy(x) + else + x + end +end + function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 56a1bbf3cfe7fd0ee6f73f578f49ecdb447c0197 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:23:58 +0000 Subject: [PATCH 09/57] const func annotation --- test/integration/enzyme/main.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES From e795d7a84044caf007d09db0d71315e5cbc318ae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:28:34 +0000 Subject: [PATCH 10/57] Disable benchmarks on non-typed-Metadata-VarInfo --- benchmarks/benchmarks.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..5fe0320cc 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -59,11 +59,11 @@ chosen_combinations = [ false, ), ("Smorgasbord", smorgasbord_instance, :typed, 
:forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), From 2e7c723caa355529ad48143c20f89bf333db634e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:36:34 +0000 Subject: [PATCH 11/57] Fix `_evaluate!!` correctly to handle submodels --- src/fastldf.jl | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index ebaf002b4..309155606 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -133,6 +133,23 @@ struct FastLDF{ end end +function _evaluate!!( + model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo +) where {F,A,D,M,TA,TD} + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) +end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + # avoid overwriting missing elements of model arguments when + # evaluating the model. 
+ deepcopy(x) + else + x + end +end + struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M _getlogdensity::F @@ -147,21 +164,10 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) - # _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) - args = map(maybe_deepcopy, model.args) - _, vi = model.f(model, OnlyAccsVarInfo(accs), args...; model.defaults...) + _, vi = _evaluate!!(model, OnlyAccsVarInfo(accs)) return f._getlogdensity(vi) end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - deepcopy(x) - else - x - end -end - function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 35116b144de297d693e5eef5d2991269af80db4c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:41:51 +0000 Subject: [PATCH 12/57] Actually fix submodel evaluate --- src/fastldf.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 309155606..b1794ffa2 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -2,7 +2,6 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs -DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) struct RangeAndLinked @@ -133,11 +132,13 @@ struct FastLDF{ end end -function _evaluate!!( - model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo -) where {F,A,D,M,TA,TD} - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) 
+function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) + if leafcontext(model.context) isa FastLDFContext + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) + else + error("Shouldn't happen") + end end maybe_deepcopy(@nospecialize(x)) = x function maybe_deepcopy(x::AbstractArray{T}) where {T} From abe4068ceccb64c9f5951540826e0102b01d89f8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:04:26 +0000 Subject: [PATCH 13/57] Document thoroughly and organise code --- src/compiler.jl | 38 +++--- src/fastldf.jl | 341 +++++++++++++++++++++++++++++++++++++----------- src/model.jl | 20 ++- 3 files changed, 306 insertions(+), 93 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? 
if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. """ get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/fastldf.jl b/src/fastldf.jl index b1794ffa2..215202230 100644 --- 
a/src/fastldf.jl +++ b/src/fastldf.jl @@ -1,9 +1,108 @@ +""" +fasteval.jl +----------- + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we implement here `OnlyAccsVarInfo`, which is a VarInfo that only +contains accumulators. When evaluating a model with `OnlyAccsVarInfo`, it is mandatory that +the model's leaf context is a `FastEvalContext`, which provides extremely fast access to +parameter values. No writing of values into VarInfo metadata is performed at all. + +Vector parameters +----------------- + +We first consider the case of parameter vectors, i.e., the case which would normally be +handled by `unflatten` and `evaluate!!`. Unfortunately, it is not enough to just store +the vector of parameters in the `FastEvalContext`, because it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +However, we want to avoid doing this. Thus, here, we _extract this information from the +VarInfo_ a single time when constructing a `FastLDF` object. 
+ +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a +general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` +approach also fails with such models. + +NamedTuple and Dict parameters +------------------------------ + +Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Such +representations are capable of handling models with variable sizes and stochastic control +flow. + +However, the path towards implementing these is straightforward: + +1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter + value, plus a boolean indicating whether the value is linked or unlinked. See the + `get_range_and_linked` function for details. + +2. We would need to implement similar contexts for NamedTuple and Dict parameters. The + functionality would be quite similar to `InitContext(InitFromParams(...))`. +""" + +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this outside of FastLDF will lead to errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +That is because values for random variables are obtained by reading from a separate entity +(such as a `FastLDFContext`), rather than from the VarInfo itself. 
+""" struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) + # Because the VarInfo has no parameters stored in it, we need to get the eltype from the + # model's leaf context. This is only possible if said leaf context is indeed a FastEval + # context. + leaf_ctx = DynamicPPL.leafcontext(model) + if leaf_ctx isa FastEvalVectorContext + return eltype(leaf_ctx.params) + else + error( + "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", + ) + end +end +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. + +$(TYPEDFIELDS) +""" struct RangeAndLinked # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} @@ -11,30 +110,55 @@ struct RangeAndLinked is_linked::Bool end -struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # The ranges of identity VarNames are stored in a NamedTuple for improved performance - # (it's around 1.5x faster). +""" + AbstractFastEvalContext + +Abstract type representing fast evaluation contexts. This currently is only subtyped by +`FastEvalVectorContext`. However, in the future, similar contexts may be implemented for +NamedTuple and Dict parameters. 
+""" +abstract type AbstractFastEvalContext <: AbstractContext end +DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf() + +""" + FastEvalVectorContext( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + params::AbstractVector{<:Real}, + ) + +A context that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to unify the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext + # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values params::T end -DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() function get_range_and_linked( - ctx::FastLDFContext, ::VarName{sym,typeof(identity)} + ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} ) where {sym} return ctx.iden_varname_ranges[sym] end -function get_range_and_linked(ctx::FastLDFContext, vn::VarName) +function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) return ctx.varname_ranges[vn] end function tilde_assume!!( - ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo + ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) - # Don't need to read the data from the varinfo at all since it's - # all inside the context. 
+ # Note that this function does not use the metadata field of `vi` at all. range_and_linked = get_range_and_linked(ctx, vn) y = @view ctx.params[range_and_linked.range] f = if range_and_linked.is_linked @@ -48,64 +172,111 @@ function tilde_assume!!( end function tilde_observe!!( - ::FastLDFContext, + ::FastEvalVectorContext, right::Distribution, left, vn::Union{VarName,Nothing}, - vi::OnlyAccsVarInfo, + vi::AbstractVarInfo, ) # This is the same as for DefaultContext vi = accumulate_observe!!(vi, right, left, vn) return left, vi end +######################################## +# Log-density functions using FastEval # +######################################## + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. 
+- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created + with a linked or unlinked VarInfo. This is done primarily to ease interoperability with + MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend +itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: + +- `fastldf.model`: The original model from which this `FastLDF` was constructed. +- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +## Extended help + +`FastLDF` uses `FastEvalVectorContext` internally to provide extremely rapid evaluation of +the model given a vector of parameters. + +Because it is common to call `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient` within tight loops, it is beneficial for us to +pre-compute as much of the information as possible when constructing the `FastLDF` object. 
+In particular, we use the provided VarInfo's metadata to extract the mapping from VarNames +to ranges and link status, and store this mapping inside the `FastLDF` object. We can later +use this to construct a FastEvalVectorContext, without having to look into a metadata again. +""" struct FastLDF{ M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, N<:NamedTuple, - AD<:Union{ADTypes.AbstractADType,Nothing}, ADP<:Union{Nothing,DI.GradientPrep}, } - _model::M + model::M + adtype::AD _getlogdensity::F - # See FastLDFContext for explanation of these two fields + # See FastLDFContext for explanation of these two fields. _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} - _adtype::AD _adprep::ADP function FastLDF( model::Model, getlogdensity::Function=getlogjoint_internal, - # This only works with typed Metadata-varinfo. - # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) where {syms} + ) # Figure out which variable corresponds to which index, and # which variables are linked. 
- all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - for (vn, idx) in md.idcs - len = length(md.ranges[idx]) - is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple(( - AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), - )), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += len - end - end + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) # Do AD prep if needed prep = if adtype === nothing nothing @@ -119,37 +290,37 @@ struct FastLDF{ x, ) end - return new{ typeof(model), + typeof(adtype), typeof(getlogdensity), typeof(all_iden_ranges), - typeof(adtype), typeof(prep), }( - model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep ) end end -function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) - if leafcontext(model.context) isa FastLDFContext - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) - else - error("Shouldn't happen") - end +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - # avoid overwriting missing elements of model arguments when - # evaluating the model. 
- deepcopy(x) - else - x - end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M @@ -158,20 +329,15 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) + ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - # This can obviously also be optimised for the case where not - # all accumulators are needed. - accs = AccumulatorTuple(( - LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() - )) - _, vi = _evaluate!!(model, OnlyAccsVarInfo(accs)) + _, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))) return f._getlogdensity(vi) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( - fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges )( params ) @@ -182,13 +348,42 @@ function LogDensityProblems.logdensity_and_gradient( ) return DI.value_and_gradient( FastLogDensityAt( - fldf._model, - fldf._getlogdensity, - fldf._iden_varname_ranges, - fldf._varname_ranges, + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges ), fldf._adprep, - fldf._adtype, + fldf.adtype, params, ) end + +###################################################### +# Helper functions to extract ranges and link status # 
+###################################################### + +# TODO: Fails for other VarInfo types. +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + # TODO: Fails for VarNamedVector. + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple(( + AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), + )), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + end + return all_iden_ranges, all_ranges +end diff --git a/src/model.jl b/src/model.jl index 94fcd9fd4..27b6157e2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)...) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1006,22 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, model::Model) + +Get the element type of the parameters being used to evaluate the `model` from the +`varinfo`. For example, when performing AD with ForwardDiff, this should return +`ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. 
+ +See `OnlyAccsVarInfo` for an example of where this is not true (the parameters are instead +stored in the model's context). +""" +get_param_eltype(varinfo::AbstractVarInfo, ::Model) = eltype(varinfo) + """ getargnames(model::Model) From 01ba783649d846a6e759af85706dd7c438bfc71f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:27:57 +0000 Subject: [PATCH 14/57] Support more VarInfos, make it thread-safe (?) --- src/fastldf.jl | 99 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 21 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 215202230..87dd698dc 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -77,13 +77,14 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) # Because the VarInfo has no parameters stored in it, we need to get the eltype from the # model's leaf context. This is only possible if said leaf context is indeed a FastEval # context. - leaf_ctx = DynamicPPL.leafcontext(model) + leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) else @@ -138,7 +139,8 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to unify the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
""" -struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: + AbstractFastEvalContext # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames @@ -331,7 +333,17 @@ end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - _, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))) + only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + vi = if Threads.nthreads() > 1 + ThreadSafeVarInfo(only_accs_vi) + else + only_accs_vi + end + _, vi = _evaluate!!(model, vi) return f._getlogdensity(vi) end @@ -360,30 +372,75 @@ end # Helper functions to extract ranges and link status # ###################################################### -# TODO: Fails for other VarInfo types. +# TODO: Fails for SimpleVarInfo. Do I really care enough? Ehhh, honestly, debatable. + +""" + get_ranges_and_linked(varinfo::VarInfo) + +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. + +This function should return a tuple containing: + +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. 
+""" function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} all_iden_ranges = NamedTuple() all_ranges = Dict{VarName,RangeAndLinked}() offset = 1 for sym in syms md = varinfo.metadata[sym] - # TODO: Fails for VarNamedVector. - for (vn, idx) in md.idcs - len = length(md.ranges[idx]) - is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple(( - AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), - )), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += len - end + this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata( + md, offset + ) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + offset = new_offset end return all_iden_ranges, all_ranges end +function get_ranges_and_linked(varinfo::VarInfo{<:Metadata}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + len = length(vnv.ranges[idx]) + is_linked = vnv.is_unconstrained[idx] 
+ range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + return all_iden_ranges, all_ranges, offset +end From 24a4519ecac5185c7d5b59715fd8e57e077e0dc6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:46:06 +0000 Subject: [PATCH 15/57] fix bug in parsing ranges from metadata/VNV --- src/fastldf.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 87dd698dc..15326a83b 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -391,16 +391,13 @@ function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms offset = 1 for sym in syms md = varinfo.metadata[sym] - this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata( - md, offset - ) + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) all_iden_ranges = merge(all_iden_ranges, this_md_iden) all_ranges = merge(all_ranges, this_md_others) - offset = new_offset end return all_iden_ranges, all_ranges end -function get_ranges_and_linked(varinfo::VarInfo{<:Metadata}) +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) return all_iden, all_others end @@ -409,9 +406,8 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) all_ranges = Dict{VarName,RangeAndLinked}() offset = start_offset for (vn, idx) in md.idcs - len = length(md.ranges[idx]) is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) + range = md.ranges[idx] .+ (start_offset - 1) if AbstractPPL.getoptic(vn) === identity all_iden_ranges = merge( all_iden_ranges, @@ -420,7 +416,7 @@ function get_ranges_and_linked_metadata(md::Metadata, 
start_offset::Int) else all_ranges[vn] = RangeAndLinked(range, is_linked) end - offset += len + offset += length(range) end return all_iden_ranges, all_ranges, offset end @@ -429,9 +425,8 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) all_ranges = Dict{VarName,RangeAndLinked}() offset = start_offset for (vn, idx) in vnv.varname_to_index - len = length(vnv.ranges[idx]) is_linked = vnv.is_unconstrained[idx] - range = offset:(offset + len - 1) + range = vnv.ranges[idx] .+ (start_offset - 1) if AbstractPPL.getoptic(vn) === identity all_iden_ranges = merge( all_iden_ranges, @@ -440,7 +435,7 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) else all_ranges[vn] = RangeAndLinked(range, is_linked) end - offset += len + offset += length(range) end return all_iden_ranges, all_ranges, offset end From b36ada65b32a9fbb0c4fe2d38770959a0e6605bf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:04:40 +0000 Subject: [PATCH 16/57] Fix get_param_eltype for TSVI --- src/fastldf.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 15326a83b..eaca0c795 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -80,7 +80,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) -function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) +function DynamicPPL.get_param_eltype( + ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model +) # Because the VarInfo has no parameters stored in it, we need to get the eltype from the # model's leaf context. This is only possible if said leaf context is indeed a FastEval # context. 
@@ -333,15 +335,16 @@ end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)) + accs = fast_ldf_accs(f._getlogdensity) # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. vi = if Threads.nthreads() > 1 - ThreadSafeVarInfo(only_accs_vi) + accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs) + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) else - only_accs_vi + OnlyAccsVarInfo(accs) end _, vi = _evaluate!!(model, vi) return f._getlogdensity(vi) From 379abfd39fa3d168624892019def9180835a61fc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:17:26 +0000 Subject: [PATCH 17/57] Disable Enzyme benchmark --- benchmarks/benchmarks.jl | 2 +- src/fastldf.jl | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 5fe0320cc..e78bf602f 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -66,7 +66,7 @@ chosen_combinations = [ # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), + # ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), diff --git a/src/fastldf.jl b/src/fastldf.jl index 
eaca0c795..c06f3495c 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -350,6 +350,25 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) return f._getlogdensity(vi) end +function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) + if leafcontext(model.context) isa FastEvalVectorContext + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) + else + error("Shouldn't happen") + end +end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + # avoid overwriting missing elements of model arguments when + # evaluating the model. + deepcopy(x) + else + x + end +end + function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 05d7b95098d90694309b8a4013be95ac5c4e3dbb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:20:31 +0000 Subject: [PATCH 18/57] Don't override _evaluate!!, that breaks ForwardDiff (sometimes) --- src/fastldf.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index c06f3495c..eaca0c795 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -350,25 +350,6 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) return f._getlogdensity(vi) end -function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) - if leafcontext(model.context) isa FastEvalVectorContext - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) - else - error("Shouldn't happen") - end -end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - # avoid overwriting missing elements of model arguments when - # evaluating the model. 
- deepcopy(x) - else - x - end -end - function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 3d206f40af2cb470252d96c0d2fe86771a7a0b1f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 20:49:18 +0000 Subject: [PATCH 19/57] Move FastLDF to experimental for now --- benchmarks/benchmarks.jl | 12 ++-- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- src/DynamicPPL.jl | 1 - src/experimental.jl | 2 + src/{fastldf.jl => fasteval.jl} | 0 src/test_utils/ad.jl | 9 ++- test/ad.jl | 78 +++++++++++++++++++------- 7 files changed, 75 insertions(+), 31 deletions(-) rename src/{fastldf.jl => fasteval.jl} (100%) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e78bf602f..035d8ff49 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -59,14 +59,14 @@ chosen_combinations = [ false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - # ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - # ("Smorgasbord", smorgasbord_instance, :typed, 
:enzyme, true), + ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index e6988d3f2..225e40cd8 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -94,7 +94,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1a5c338ff..c43bd89d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -196,7 +196,6 @@ include("simple_varinfo.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") -include("fastldf.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/experimental.jl b/src/experimental.jl index 8c82dca68..31cece3b4 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,6 +2,8 @@ module Experimental using DynamicPPL: DynamicPPL +include("fastldf.jl") + # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) 
diff --git a/src/fastldf.jl b/src/fasteval.jl similarity index 100% rename from src/fastldf.jl rename to src/fasteval.jl diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index fbbae85b7..a49ffd18b 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link +using DynamicPPL: + Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -264,7 +265,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype) + ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 @@ -281,7 +282,9 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype) + ldf_reference = LogDensityFunction( + model, getlogdensity, varinfo; adtype=test.adtype + ) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) diff --git a/test/ad.jl b/test/ad.jl index 48b1b64ec..d7505aab2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,5 @@ -using DynamicPPL: FastLDF +using DynamicPPL: LogDensityFunction using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest -using Random: Xoshiro @testset "Automatic differentiation" begin # Used as the 
ground truth that others are compared against. @@ -16,25 +15,64 @@ using Random: Xoshiro [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") LogDensityFunction( + demo(); adtype=AutoZygote() + ) + end + @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(Xoshiro(468), m) - linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = linked_varinfo[:] - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $adtype" - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) + x = DynamicPPL.getparams(f) + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" + + # Put predicates here to avoid long lines + is_mooncake = adtype isa AutoMooncake + is_1_10 = v"1.10" <= VERSION < v"1.11" + is_1_11 = v"1.11" <= VERSION < v"1.12" + is_svi_vnv = + linked_varinfo isa 
SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} + + # Mooncake doesn't work with several combinations of SimpleVarInfo. + if is_mooncake && is_1_11 && is_svi_vnv + # https://github.com/compintell/Mooncake.jl/issues/470 + @test_throws ArgumentError DynamicPPL.LogDensityFunction( + m, getlogjoint_internal, linked_varinfo; adtype=adtype + ) + elseif is_mooncake && is_1_10 && is_svi_vnv + # TODO: report upstream + @test_throws UndefRefError DynamicPPL.LogDensityFunction( + m, getlogjoint_internal, linked_varinfo; adtype=adtype + ) + elseif is_mooncake && is_1_10 && is_svi_od + # TODO: report upstream + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( + m, getlogjoint_internal, linked_varinfo; adtype=adtype + ) + else + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end end end end @@ -45,7 +83,7 @@ using Random: Xoshiro test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) + ldf = LogDensityFunction(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end From 007da25bdb337ef78968ddeeaedc62ca63cc3d96 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 20:56:09 +0000 Subject: [PATCH 20/57] Fix imports, add tests, etc --- src/experimental.jl | 2 +- src/fasteval.jl | 40 +++++++++++++-- test/fasteval.jl | 115 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 153 insertions(+), 5 deletions(-) create mode 100644 test/fasteval.jl diff --git a/src/experimental.jl b/src/experimental.jl index 31cece3b4..c644c09b2 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -2,7 +2,7 @@ module Experimental using DynamicPPL: DynamicPPL -include("fastldf.jl") +include("fasteval.jl") # This file only defines the names of the functions, and their docstrings. 
The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ diff --git a/src/fasteval.jl b/src/fasteval.jl index eaca0c795..2bd7b3a0c 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,35 @@ However, the path towards implementing these is straightforward: functionality would be quite similar to `InitContext(InitFromParams(...))`. """ +using DynamicPPL: + AbstractContext, + AbstractVarInfo, + AccumulatorTuple, + Metadata, + Model, + ThreadSafeVarInfo, + VarInfo, + VarNamedVector, + accumulate_assume!!, + accumulate_observe!!, + default_accumulators, + float_type_with_fallback, + from_linked_vec_transform, + from_vec_transform, + getlogjoint, + getlogjoint_internal, + getloglikelihood, + getlogprior, + getlogprior_internal, + leafcontext +using ADTypes: ADTypes +using Bijectors: with_logabsdet_jacobian +using AbstractPPL: AbstractPPL, VarName +using Distributions: Distribution +using DocStringExtensions: TYPEDFIELDS +using LogDensityProblems: LogDensityProblems +import DifferentiationInterface as DI + """ OnlyAccsVarInfo @@ -121,7 +150,7 @@ Abstract type representing fast evaluation contexts. This currently is only subt NamedTuple and Dict parameters. """ abstract type AbstractFastEvalContext <: AbstractContext end -DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf() +DynamicPPL.NodeTrait(::AbstractFastEvalContext) = DynamicPPL.IsLeaf() """ FastEvalVectorContext( @@ -286,7 +315,7 @@ struct FastLDF{ nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) + adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) x = [val for val in varinfo[:]] DI.prepare_gradient( FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), @@ -341,12 +370,15 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. 
vi = if Threads.nthreads() > 1 - accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs) + accs = map( + acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), + accs, + ) ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) else OnlyAccsVarInfo(accs) end - _, vi = _evaluate!!(model, vi) + _, vi = DynamicPPL._evaluate!!(model, vi) return f._getlogdensity(vi) end diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..563d2adba --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,115 @@ +module DynamicPPLFastLDFTests + +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.Experimental: FastLDF + +@testset "Automatic differentiation" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + + test_adtypes = if MOONCAKE_SUPPORTED + [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + else + [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] + end + + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") FastLDF(demo(); adtype=AutoZygote()) + end + + @testset "Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = linked_varinfo[:] + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" + + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any + end + end + 
end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. + @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = FastLDF(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + 
@test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 861d3bb87..10fac8b0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,7 @@ include("test_util.jl") include("ext/DynamicPPLMooncakeExt.jl") end include("ad.jl") + include("fasteval.jl") end @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." From 1ee190f429a164e94ce35dabff40c9943c336d66 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:00:14 +0000 Subject: [PATCH 21/57] More test fixes --- ext/DynamicPPLEnzymeCoreExt.jl | 2 +- ext/DynamicPPLMooncakeExt.jl | 6 ++++-- test/fasteval.jl | 18 ++++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 29a4e2cc7..eacc35046 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -9,7 +9,7 @@ using EnzymeCore nothing # Likewise for get_range_and_linked. @inline EnzymeCore.EnzymeRules.inactive( - ::typeof(DynamicPPL.get_range_and_linked), args... + ::typeof(DynamicPPL.Experimental.get_range_and_linked), args... ) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index e49b81cb2..63c754f4c 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,10 +1,12 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked +using DynamicPPL: DynamicPPL, is_transformed using Mooncake: Mooncake # This is purely an optimisation. 
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} -Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ + typeof(DynamicPPL.Experimental.get_range_and_linked),Vararg +} end # module diff --git a/test/fasteval.jl b/test/fasteval.jl index 563d2adba..3d11c6889 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -5,12 +5,23 @@ using Distributions using DistributionsAD: filldist using ADTypes using DynamicPPL.Experimental: FastLDF +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using Test + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +# Need to include this block here in case we run this test file standalone +@static if VERSION < v"1.12" + using Pkg + Pkg.add("Mooncake") + using Mooncake: Mooncake +end @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. ref_adtype = AutoForwardDiff() - test_adtypes = if MOONCAKE_SUPPORTED + test_adtypes = @static if VERSION < v"1.12" [ AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true), @@ -20,11 +31,6 @@ using DynamicPPL.Experimental: FastLDF [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") FastLDF(demo(); adtype=AutoZygote()) - end - @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS varinfo = VarInfo(m) From 0ad40841260d9bd6742d50940a90ded429be3e22 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:37:26 +0000 Subject: [PATCH 22/57] Fix imports / tests --- src/fasteval.jl | 26 ++++++++++++++++---------- test/fasteval.jl | 43 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 2bd7b3a0c..533e19812 100644 --- 
a/src/fasteval.jl +++ b/src/fasteval.jl @@ -50,14 +50,10 @@ Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Suc representations are capable of handling models with variable sizes and stochastic control flow. -However, the path towards implementing these is straightforward: - -1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter - value, plus a boolean indicating whether the value is linked or unlinked. See the - `get_range_and_linked` function for details. - -2. We would need to implement similar contexts for NamedTuple and Dict parameters. The - functionality would be quite similar to `InitContext(InitFromParams(...))`. +However, the path towards implementing these is straightforward: just make `InitContext` work +correctly with `OnlyAccsVarInfo`. There will probably be a few functions that need to be +overloaded to make this work: for example `push!!` on `OnlyAccsVarInfo` can just be defined +as a no-op. """ using DynamicPPL: @@ -119,6 +115,13 @@ function DynamicPPL.get_param_eltype( if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) else + # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. + # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to + # figure out the parameter type from a NamedTuple or Dict. The benefit of + # implementing this for InitContext is that we could then use OnlyAccsVarInfo with + # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe + # that Mooncake / Enzyme should be able to differentiate through that too and + # provide a NamedTuple of gradients (although I haven't tested this yet). 
error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) @@ -188,7 +191,7 @@ function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) return ctx.varname_ranges[vn] end -function tilde_assume!!( +function DynamicPPL.tilde_assume!!( ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) # Note that this function does not use the metadata field of `vi` at all. @@ -204,7 +207,7 @@ function tilde_assume!!( return x, vi end -function tilde_observe!!( +function DynamicPPL.tilde_observe!!( ::FastEvalVectorContext, right::Distribution, left, @@ -369,6 +372,9 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. 
+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 vi = if Threads.nthreads() > 1 accs = map( acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), diff --git a/test/fasteval.jl b/test/fasteval.jl index 3d11c6889..277384350 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -1,12 +1,15 @@ module DynamicPPLFastLDFTests +using AbstractPPL: AbstractPPL using DynamicPPL using Distributions using DistributionsAD: filldist using ADTypes using DynamicPPL.Experimental: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I using Test +using LogDensityProblems: LogDensityProblems using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff @@ -17,7 +20,43 @@ using ReverseDiff: ReverseDiff using Mooncake: Mooncake end -@testset "Automatic differentiation" begin +@testset "get_ranges_and_linked" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = vi[:] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + end + end + end +end + +@testset "AD with FastLDF" begin # Used as the ground truth that others are compared against. 
ref_adtype = AutoForwardDiff() @@ -43,7 +82,7 @@ end ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" + @info "Testing AD on: $(m.f) - $adtype" @test run_ad( m, From 1c1ca932c00073e56578021e0604aa07153ad790 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:56:32 +0000 Subject: [PATCH 23/57] Remove AbstractFastEvalContext --- src/fasteval.jl | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 533e19812..a0df8b373 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -145,22 +145,12 @@ struct RangeAndLinked is_linked::Bool end -""" - AbstractFastEvalContext - -Abstract type representing fast evaluation contexts. This currently is only subtyped by -`FastEvalVectorContext`. However, in the future, similar contexts may be implemented for -NamedTuple and Dict parameters. -""" -abstract type AbstractFastEvalContext <: AbstractContext end -DynamicPPL.NodeTrait(::AbstractFastEvalContext) = DynamicPPL.IsLeaf() - """ FastEvalVectorContext( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, params::AbstractVector{<:Real}, - ) + ) <: AbstractContext A context that wraps a vector of parameter values, plus information about how random variables map to ranges in that vector. @@ -173,8 +163,7 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to unify the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
""" -struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: - AbstractFastEvalContext +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames @@ -182,6 +171,8 @@ struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: # The full parameter vector which we index into to get variable values params::T end +DynamicPPL.NodeTrait(::FastEvalVectorContext) = DynamicPPL.IsLeaf() + function get_range_and_linked( ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} ) where {sym} From f2b56243cdde13dadd307e3f35c19fbed8ba2da3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 21:59:47 +0000 Subject: [PATCH 24/57] Changelog and patch bump --- HISTORY.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index f181897f7..f56f85e8d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,18 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +### Other changes + +#### FastLDF + +Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. +Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. + +Please note that `FastLDF` is currently considered internal and its API may change without warning. +We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. 
+ +For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. From 4cf8a1bcd0e9b110bcd51d7f94bd11c6ad24206f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 22:07:56 +0000 Subject: [PATCH 25/57] Add correctness tests, fix imports --- src/fasteval.jl | 3 +++ test/fasteval.jl | 22 +++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index a0df8b373..c91254d43 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,9 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + LogJacobianAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, Metadata, Model, ThreadSafeVarInfo, diff --git a/test/fasteval.jl b/test/fasteval.jl index 277384350..051c82f0f 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -20,7 +20,7 @@ using ReverseDiff: ReverseDiff using Mooncake: Mooncake end -@testset "get_ranges_and_linked" begin +@testset "FastLDF: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS @testset "$varinfo_func" for varinfo_func in [ DynamicPPL.untyped_varinfo, @@ -51,6 +51,26 @@ end # Check that the link status is correct @test range_with_linked.is_linked == islinked end + + # Compare results of FastLDF vs ordinary LogDensityFunction. These tests + # can eventually go once we replace LogDensityFunction with FastLDF, but + # for now it helps to have this check! (Eventually we should just check + # against manually computed log-densities). + # + # TODO(penelopeysm): I think we need to add tests for some really + # pathological models here. 
+ @testset "$getlogdensity" for getlogdensity in ( + DynamicPPL.getlogjoint_internal, + DynamicPPL.getlogjoint, + DynamicPPL.getloglikelihood, + DynamicPPL.getlogprior_internal, + DynamicPPL.getlogprior, + ) + ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) + fldf = FastLDF(m, getlogdensity, vi) + @test LogDensityProblems.logdensity(ldf, params) ≈ + LogDensityProblems.logdensity(fldf, params) + end end end end From 7a6d0be541a860d1a127fa54a99df6b822b70866 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 17:25:43 +0000 Subject: [PATCH 26/57] Concretise parameter vector in tests --- test/fasteval.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index 051c82f0f..863ec369e 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -36,7 +36,7 @@ end unlinked_vi end nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) - params = vi[:] + params = map(identity, vi[:]) # Iterate over all variables for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using the ranges @@ -95,7 +95,7 @@ end varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = linked_varinfo[:] + x = map(identity, linked_varinfo[:]) # Calculate reference logp + gradient of logp using ForwardDiff ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) From 587fa457099a56c33f11c2529bf4d285b989792f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Nov 2025 02:42:07 +0000 Subject: [PATCH 27/57] Add zero-allocation tests --- test/fasteval.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/fasteval.jl b/test/fasteval.jl index 863ec369e..9907e18a1 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -1,6 +1,7 @@ module DynamicPPLFastLDFTests using AbstractPPL: AbstractPPL +using Chairmarks using DynamicPPL using Distributions using DistributionsAD: 
filldist @@ -76,6 +77,34 @@ end end end +@testset "FastLDF: performance" begin + # Evaluating these three models should not lead to any allocations. + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF(model, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end +end + @testset "AD with FastLDF" begin # Used as the ground truth that others are compared against. 
ref_adtype = AutoForwardDiff() From 045385ed5c54d96e33ddb4b76e173a09b1752145 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Nov 2025 03:01:34 +0000 Subject: [PATCH 28/57] Add Chairmarks as test dep --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..efd916308 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" From 6e9fc1d3151ef539388c154e19f6c1f3d3de0751 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Nov 2025 03:14:18 +0000 Subject: [PATCH 29/57] Disable allocations tests on multi-threaded --- test/fasteval.jl | 53 ++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index 9907e18a1..57a2c937d 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -78,30 +78,35 @@ end end @testset "FastLDF: performance" begin - # Evaluating these three models should not lead to any allocations. 
- @model function f() - x ~ Normal() - return 1.0 ~ Normal(x) - end - @model function submodel_inner() - m ~ Normal(0, 1) - s ~ Exponential() - return (m=m, s=s) - end - # Note that for the allocation tests to work on this one, `inner` has - # to be passed as an argument to `submodel_outer`, instead of just - # being called inside the model function itself - @model function submodel_outer(inner) - params ~ to_submodel(inner) - y ~ Normal(params.m, params.s) - return 1.0 ~ Normal(y) - end - @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF(model, DynamicPPL.getlogjoint_internal, vi) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) + if Threads.nthreads() == 1 + # Evaluating these three models should not lead to any allocations (but only when + # not using TSVI). + @model function f() + x ~ Normal() + return 1.0 ~ Normal(x) + end + @model function submodel_inner() + m ~ Normal(0, 1) + s ~ Exponential() + return (m=m, s=s) + end + # Note that for the allocation tests to work on this one, `inner` has + # to be passed as an argument to `submodel_outer`, instead of just + # being called inside the model function itself + @model function submodel_outer(inner) + params ~ to_submodel(inner) + y ~ Normal(params.m, params.s) + return 1.0 ~ Normal(y) + end + @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) + vi = VarInfo(model) + fldf = DynamicPPL.Experimental.FastLDF( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end end end From 1fee38f1d21ff56be9368287ef89a6b386701ad0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 10 Nov 2025 19:28:52 +0000 Subject: [PATCH 30/57] Fast InitContext (#1125) * Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * 
Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation --- src/contexts/init.jl | 2 +- src/fasteval.jl | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 83507353f..be2dc1b8a 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..fbc6a61ce 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,10 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -81,6 +85,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,14 +125,12 @@ function 
DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} + return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) + elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} + # No need to enforce any particular eltype here, since new parameters are sampled + return Any else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) From 40ac87d5e99533236f0c933dc7f2a4e13b0877e3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 12:19:18 +0000 Subject: [PATCH 31/57] Refactor FastLDF to use InitContext --- ext/DynamicPPLEnzymeCoreExt.jl | 2 +- ext/DynamicPPLMooncakeExt.jl | 2 +- src/DynamicPPL.jl | 1 + src/contexts/init.jl | 145 +++++++++++++++++++----- src/fasteval.jl | 196 ++++++--------------------------- src/model.jl | 38 +++++-- src/onlyaccs.jl | 29 +++++ 7 files changed, 214 insertions(+), 199 deletions(-) create mode 100644 src/onlyaccs.jl diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index eacc35046..ef21c255b 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -9,7 +9,7 @@ using EnzymeCore nothing # Likewise for get_range_and_linked. 
@inline EnzymeCore.EnzymeRules.inactive( - ::typeof(DynamicPPL.Experimental.get_range_and_linked), args... + ::typeof(DynamicPPL._get_range_and_linked), args... ) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 63c754f4c..8adf66030 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -6,7 +6,7 @@ using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ - typeof(DynamicPPL.Experimental.get_range_and_linked),Vararg + typeof(DynamicPPL._get_range_and_linked),Vararg } end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c43bd89d5..06d8effd8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -193,6 +193,7 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") +include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl index be2dc1b8a..6aff63e29 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -9,16 +9,24 @@ Any subtype of `AbstractInitStrategy` must implement the """ abstract type AbstractInitStrategy end +""" + InitValue{T,F} + +A wrapper type representing a value of type `T`. The function `F` indicates what transform +to apply to the value to convert it back to the unlinked space. If `value` is already in +unlinked space, then `transform` can be `identity`. +""" +struct InitValue{T,F} + value::T + transform::F +end + """ init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) Generate a new value for a random variable with the given distribution. -!!! 
warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a `InitValue`. """ function init end @@ -29,7 +37,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist) + return InitValue(rand(rng, dist), identity) end """ @@ -69,7 +77,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x + return InitValue(x, identity) end """ @@ -93,19 +101,20 @@ will be thrown. The default for `fallback` is `InitFromPrior()`. struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S + function InitFromParams( - params::AbstractDict{<:VarName}, - fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(), - ) - return new{typeof(params),typeof(fallback)}(params, fallback) - end - function InitFromParams( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) - return new{typeof(params),typeof(fallback)}(params, fallback) + params::P, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) where {P} + return new{P,typeof(fallback)}(params, fallback) end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) + +function init( + rng::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}}, +) # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # 
error if they aren't. This is actually quite non-trivial though because @@ -119,7 +128,7 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - x + InitValue(x, identity) end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") @@ -127,6 +136,74 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF end end +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. + +$(TYPEDFIELDS) +""" +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +""" + VectorWithRanges( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + vect::AbstractVector{<:Real}, + ) + +A struct that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to improve the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
+""" +struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} + # This NamedTuple stores the ranges for identity VarNames + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames + varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values + vect::T +end + +function _get_range_and_linked( + vr::VectorWithRanges, ::VarName{sym,typeof(identity)} +) where {sym} + return vr.iden_varname_ranges[sym] +end +function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) + return vr.varname_ranges[vn] +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = if range_and_linked.is_linked + from_linked_vec_transform(dist) + else + from_vec_transform(dist) + end + return InitValue((@view vr.vect[range_and_linked.range]), transform) +end + """ InitContext( [rng::Random.AbstractRNG=Random.default_rng()], @@ -155,9 +232,8 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - # `init()` always returns values in original space, i.e. possibly - # constrained - x = init(ctx.rng, vn, dist, ctx.strategy) + init_val = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(init_val.transform, init_val.value) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. If not, we @@ -165,17 +241,34 @@ function tilde_assume!!( # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. insert_transformed_value = in_varinfo ? 
is_transformed(vi, vn) : is_transformed(vi) - y, logjac = if insert_transformed_value - with_logabsdet_jacobian(link_transform(dist), x) + val_to_insert, logjac = if insert_transformed_value + # Calculate the forward logjac and sum them up. + y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian + # calculation wastes a lot of time going from linked vectorised -> unlinked -> + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. However, + # `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which case this + # branch is never hit (since `in_varinfo` will always be false). So we can leave + # this branch in for full generality with other combinations of init strategies / + # VarInfo. + # + # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue + # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, + # which is NOT the same as `inverse(link_transform)` (because there is an additional + # vectorisation step). We need `init` and `tilde_assume!!` to share this information + # but it's not clear right now how to do this. In my opinion, the most productive + # way forward would be to standardise the behaviour of bijectors so that we can have + # a clean separation between the linking and vectorisation parts of it. + y, inv_logjac + fwd_logjac else - x, zero(LogProbType) + x, inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo - vi = setindex!!(vi, y, vn) + vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, y, dist) + vi = push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. 
diff --git a/src/fasteval.jl b/src/fasteval.jl index fbc6a61ce..3f933adf6 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -17,16 +17,25 @@ we care about from model evaluation are those which are stored in accumulators, probability densities, or `ValuesAsInModel`. To avoid this issue, we implement here `OnlyAccsVarInfo`, which is a VarInfo that only -contains accumulators. When evaluating a model with `OnlyAccsVarInfo`, it is mandatory that -the model's leaf context is a `FastEvalContext`, which provides extremely fast access to -parameter values. No writing of values into VarInfo metadata is performed at all. +contains accumulators. It implements enough of the `AbstractVarInfo` interface to not error +during model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext{<:InitFromParams}`. + +NamedTuple and Dict parameters +------------------------------ + +OnlyAccsVarInfo works out of the box with the existing `InitContext{<:InitFromParams{<:P}}`, +functionality, where `P` is either a NamedTuple or a Dict. Vector parameters ----------------- -We first consider the case of parameter vectors, i.e., the case which would normally be -handled by `unflatten` and `evaluate!!`. Unfortunately, it is not enough to just store -the vector of parameters in the `FastEvalContext`, because it is not clear: +Vector parameters are more complicated, since it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: - which parts of the vector correspond to which random variables, and - whether the variables are linked or unlinked. @@ -36,192 +45,50 @@ place values into the VarInfo's metadata alongside the information about ranges However, we want to avoid doing this. 
Thus, here, we _extract this information from the VarInfo_ a single time when constructing a `FastLDF` object. +This creates a single struct, `ParamsWithRanges`, which contains: + + - the vector of parameters + - a mapping from VarNames to ranges in that vector, along with link status + +When evaluating the model, we create a `FastEvalVectorContext`, which reads parameter + Note that this assumes that the ranges and link status are static throughout the lifetime of the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable numbers of parameters, or models which may visit random variables in different orders depending on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` approach also fails with such models. - -NamedTuple and Dict parameters ------------------------------- - -Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Such -representations are capable of handling models with variable sizes and stochastic control -flow. - -However, the path towards implementing these is straightforward: just make `InitContext` work -correctly with `OnlyAccsVarInfo`. There will probably be a few functions that need to be -overloaded to make this work: for example `push!!` on `OnlyAccsVarInfo` can just be defined -as a no-op. 
""" using DynamicPPL: - AbstractContext, AbstractVarInfo, AccumulatorTuple, InitContext, InitFromParams, - InitFromPrior, - InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, - Metadata, Model, ThreadSafeVarInfo, VarInfo, + OnlyAccsVarInfo, + RangeAndLinked, + VectorWithRanges, + Metadata, VarNamedVector, - accumulate_assume!!, - accumulate_observe!!, default_accumulators, float_type_with_fallback, - from_linked_vec_transform, - from_vec_transform, getlogjoint, getlogjoint_internal, getloglikelihood, getlogprior, - getlogprior_internal, - leafcontext + getlogprior_internal using ADTypes: ADTypes using BangBang: BangBang -using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName -using Distributions: Distribution -using DocStringExtensions: TYPEDFIELDS using LogDensityProblems: LogDensityProblems import DifferentiationInterface as DI - -""" - OnlyAccsVarInfo - -This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` -interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions. - -Note that this does not implement almost every other AbstractVarInfo interface function, and -so using this outside of FastLDF will lead to errors. - -Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. -That is because values for random variables are obtained by reading from a separate entity -(such as a `FastLDFContext`), rather than from the VarInfo itself. 
-""" -struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo - accs::Accs -end -OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) -DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi -DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs -DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) -@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false -@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false -@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi -function DynamicPPL.get_param_eltype( - ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model -) - # Because the VarInfo has no parameters stored in it, we need to get the eltype from the - # model's leaf context. This is only possible if said leaf context is indeed a FastEval - # context. - leaf_ctx = DynamicPPL.leafcontext(model.context) - if leaf_ctx isa FastEvalVectorContext - return eltype(leaf_ctx.params) - elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} - return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) - elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} - # No need to enforce any particular eltype here, since new parameters are sampled - return Any - else - error( - "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", - ) - end -end - -""" - RangeAndLinked - -Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable -in the model will in general correspond to a sub-vector of `params`. This struct stores -information about that range, as well as whether the sub-vector represents a linked value or -an unlinked value. 
- -$(TYPEDFIELDS) -""" -struct RangeAndLinked - # indices that the variable corresponds to in the vectorised parameter - range::UnitRange{Int} - # whether it's linked - is_linked::Bool -end - -""" - FastEvalVectorContext( - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, - params::AbstractVector{<:Real}, - ) <: AbstractContext - -A context that wraps a vector of parameter values, plus information about how random -variables map to ranges in that vector. - -In the simplest case, this could be accomplished only with a single dictionary mapping -VarNames to ranges and link status. However, for performance reasons, we separate out -VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All -non-identity-optic VarNames are stored in the `varname_ranges` Dict. - -It would be nice to unify the NamedTuple and Dict approach. See, e.g. -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. -""" -struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # This NamedTuple stores the ranges for identity VarNames - iden_varname_ranges::N - # This Dict stores the ranges for all other VarNames - varname_ranges::Dict{VarName,RangeAndLinked} - # The full parameter vector which we index into to get variable values - params::T -end -DynamicPPL.NodeTrait(::FastEvalVectorContext) = DynamicPPL.IsLeaf() - -function get_range_and_linked( - ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} -) where {sym} - return ctx.iden_varname_ranges[sym] -end -function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) - return ctx.varname_ranges[vn] -end - -function DynamicPPL.tilde_assume!!( - ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - # Note that this function does not use the metadata field of `vi` at all. 
- range_and_linked = get_range_and_linked(ctx, vn) - y = @view ctx.params[range_and_linked.range] - f = if range_and_linked.is_linked - from_linked_vec_transform(right) - else - from_vec_transform(right) - end - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) - return x, vi -end - -function DynamicPPL.tilde_observe!!( - ::FastEvalVectorContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - # This is the same as for DefaultContext - vi = accumulate_observe!!(vi, right, left, vn) - return left, vi -end - -######################################## -# Log-density functions using FastEval # -######################################## +using Random: Random """ FastLDF( @@ -365,7 +232,12 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) + ctx = InitContext( + Random.default_rng(), + InitFromParams( + VectorWithRanges(f._iden_varname_ranges, f._varname_ranges, params), nothing + ), + ) model = DynamicPPL.setleafcontext(f._model, ctx) accs = fast_ldf_accs(f._getlogdensity) # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, diff --git a/src/model.jl b/src/model.jl index 27b6157e2..71b4d228f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... 
+ ) else - :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var)) end for var in argnames ] return quote @@ -1007,20 +1011,36 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end """ - get_param_eltype(varinfo::AbstractVarInfo, model::Model) + get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext) -Get the element type of the parameters being used to evaluate the `model` from the -`varinfo`. For example, when performing AD with ForwardDiff, this should return -`ForwardDiff.Dual`. +Get the element type of the parameters being used to evaluate a model, using a `varinfo` +under the given `context`. For example, when evaluating a model with ForwardDiff AD, this +should return `ForwardDiff.Dual`. By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact that typically, before evaluation, the parameters will have been inserted into the VarInfo's metadata field. -See `OnlyAccsVarInfo` for an example of where this is not true (the parameters are instead -stored in the model's context). +For InitContext, it's quite different: because InitContext is responsible for supplying the +parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside it. 
""" -get_param_eltype(varinfo::AbstractVarInfo, ::Model) = eltype(varinfo) +get_param_eltype(vi::AbstractVarInfo, ::DefaultContext) = eltype(vi) +function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractContext) + return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) +end +function get_param_eltype(::AbstractVarInfo, ctx::InitContext) + return _get_strat_param_eltype(ctx.strategy) +end +function _get_strat_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end +function _get_strat_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end +# No need to specify a type since new ones are generated +_get_strat_param_eltype(::Union{InitFromPrior,InitFromUniform}) = Any """ getargnames(model::Model) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl new file mode 100644 index 000000000..2e9f34177 --- /dev/null +++ b/src/onlyaccs.jl @@ -0,0 +1,29 @@ +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this outside of FastLDF will lead to errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +That is because values for random variables are obtained by reading from a separate entity +(such as a `FastLDFContext`), rather than from the VarInfo itself. 
+""" +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} + return OnlyAccsVarInfo(AccumulatorTuple(accs)) +end + +# AbstractVarInfo interface +@inline DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +@inline DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +@inline DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = + OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi From 83add24ac1b258e3de438982addb935b2a47915e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 12:34:51 +0000 Subject: [PATCH 32/57] note init breaking change --- HISTORY.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index f56f85e8d..5c3315b1b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,9 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return an `InitValue` struct, which holds both a value as well as a transform function that maps it back to unlinked space. +This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). 
+ ### Other changes #### FastLDF From f840dbc8cbf309fe3944c2045e3bcca3769684aa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 12:55:58 +0000 Subject: [PATCH 33/57] fix logjac sign --- src/contexts/init.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 6aff63e29..bb600f226 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -259,9 +259,9 @@ function tilde_assume!!( # but it's not clear right now how to do this. In my opinion, the most productive # way forward would be to standardise the behaviour of bijectors so that we can have # a clean separation between the linking and vectorisation parts of it. - y, inv_logjac + fwd_logjac + y, -inv_logjac + fwd_logjac else - x, inv_logjac + x, -inv_logjac end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. From 5d54e13d3c0656af0a6cd111b401a5a552e9bce1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 13:08:19 +0000 Subject: [PATCH 34/57] workaround Mooncake segfault --- src/contexts/init.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index bb600f226..daca475ca 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -12,9 +12,6 @@ abstract type AbstractInitStrategy end """ InitValue{T,F} -A wrapper type representing a value of type `T`. The function `F` indicates what transform -to apply to the value to convert it back to the unlinked space. If `value` is already in -unlinked space, then `transform` can be `identity`. """ struct InitValue{T,F} value::T @@ -26,7 +23,11 @@ end Generate a new value for a random variable with the given distribution. -This function must return a `InitValue`. +This function must return a tuple of: + +- the generated value +- a function that transforms the generated value back to the unlinked space. 
If the value is + already in unlinked space, then this should be `identity`. """ function init end @@ -37,7 +38,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return InitValue(rand(rng, dist), identity) + return rand(rng, dist), identity end """ @@ -77,7 +78,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return InitValue(x, identity) + return x, identity end """ @@ -128,7 +129,7 @@ function init( else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - InitValue(x, identity) + x, identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") @@ -201,7 +202,7 @@ function init( else from_vec_transform(dist) end - return InitValue((@view vr.vect[range_and_linked.range]), transform) + return (@view vr.vect[range_and_linked.range]), transform end """ @@ -232,8 +233,8 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) - init_val = init(ctx.rng, vn, dist, ctx.strategy) - x, inv_logjac = with_logabsdet_jacobian(init_val.transform, init_val.value) + val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. 
If not, we From c129125f94c51dfcdbc5fba5d4295a461e6bf4ab Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 13:10:56 +0000 Subject: [PATCH 35/57] fix changelog too --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 5c3315b1b..0f0102ce4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,7 +21,7 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. -The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return an `InitValue` struct, which holds both a value as well as a transform function that maps it back to unlinked space. +The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). ### Other changes From 00ad3ac92272dce4c5b65d3fc263a12d154f1d0a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 15:29:14 +0000 Subject: [PATCH 36/57] Fix get_param_eltype for context stacks --- src/model.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 71b4d228f..0219e6046 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1024,13 +1024,14 @@ metadata field. For InitContext, it's quite different: because InitContext is responsible for supplying the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside it. 
""" -get_param_eltype(vi::AbstractVarInfo, ::DefaultContext) = eltype(vi) -function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractContext) +function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext) return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) end +get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi) function get_param_eltype(::AbstractVarInfo, ctx::InitContext) return _get_strat_param_eltype(ctx.strategy) end + function _get_strat_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) return eltype(strategy.params.vect) end From a371b886390b9c78c6acb5fa2e9c4d9da052120c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 15:41:12 +0000 Subject: [PATCH 37/57] Add a test for threaded observe --- test/fasteval.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/fasteval.jl b/test/fasteval.jl index 57a2c937d..f1c535643 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -75,6 +75,24 @@ end end end end + + @testset "Threaded observe" begin + if Threads.nthreads() > 1 + @model function threaded(y) + x ~ Normal() + Threads.@threads for i in eachindex(y) + y[i] ~ Normal(x) + end + end + N = 100 + model = threaded(zeros(N)) + ldf = DynamicPPL.Experimental.FastLDF(model) + + xs = [1.0] + @test LogDensityProblems.logdensity(ldf, xs) ≈ + logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0) + end + end end @testset "FastLDF: performance" begin From 2c3563c8c8991ab5a7ecc4cb85c03d9862b83fdf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 15:48:27 +0000 Subject: [PATCH 38/57] Export init --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 06d8effd8..0baa7e828 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -113,6 +113,7 @@ export AbstractVarInfo, InitFromPrior, InitFromUniform, InitFromParams, + init, # Pseudo distributions NamedDist, NoDist, From 
f12b4077ec647e4d60c166c9b1e789bfc3b7184c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 15:56:13 +0000 Subject: [PATCH 39/57] Remove dead code --- docs/src/api.md | 4 ++-- src/contexts/init.jl | 9 --------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 63dafdfca..99ffc4a4c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -519,8 +519,8 @@ InitFromParams If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. ```@docs -DynamicPPL.AbstractInitStrategy -DynamicPPL.init +AbstractInitStrategy +init ``` ### Choosing a suitable VarInfo diff --git a/src/contexts/init.jl b/src/contexts/init.jl index daca475ca..a1626b67f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -9,15 +9,6 @@ Any subtype of `AbstractInitStrategy` must implement the """ abstract type AbstractInitStrategy end -""" - InitValue{T,F} - -""" -struct InitValue{T,F} - value::T - transform::F -end - """ init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) From de88c78c1a20cfedd03a9cdde154a9536eb51d32 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 16:42:17 +0000 Subject: [PATCH 40/57] fix transforms for pathological distributions --- src/contexts/init.jl | 26 ++++++++++++++++---------- src/utils.jl | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a1626b67f..870802ace 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -29,7 +29,7 @@ Obtain new values by sampling from the prior distribution. 
""" struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist), identity + return rand(rng, dist), _typed_identity end """ @@ -69,7 +69,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x, identity + return x, _typed_identity end """ @@ -120,7 +120,7 @@ function init( else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - x, identity + x, _typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") @@ -238,19 +238,25 @@ function tilde_assume!!( y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian # calculation wastes a lot of time going from linked vectorised -> unlinked -> - # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. However, - # `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which case this - # branch is never hit (since `in_varinfo` will always be false). So we can leave - # this branch in for full generality with other combinations of init strategies / - # VarInfo. + # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. + # + # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which + # case this branch is never hit (since `in_varinfo` will always be false). It does + # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, + # linked, VarInfo will be very slow. That should never really be used, though. So + # (at least for now) we can leave this branch in for full generality with other + # combinations of init strategies / VarInfo. # # TODO(penelopeysm): Figure out one day how to refactor this. 
The crux of the issue # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, # which is NOT the same as `inverse(link_transform)` (because there is an additional # vectorisation step). We need `init` and `tilde_assume!!` to share this information # but it's not clear right now how to do this. In my opinion, the most productive - # way forward would be to standardise the behaviour of bijectors so that we can have - # a clean separation between the linking and vectorisation parts of it. + # way forward would be to clean up the behaviour of bijectors so that we can have a + # clean separation between the linking and vectorisation parts of it. That way, `x` + # can either be unlinked, unlinked vectorised, linked, or linked vectorised, and + # regardless of which it is, we should only need to apply at most one linking and + # one vectorisation transform. y, -inv_logjac + fwd_logjac else x, -inv_logjac diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..cd91ff332 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. """ const LogProbType = float(Real) +""" + _typed_identity(x) + +Identity function, but with an overload for `with_logabsdet_jacobian` to ensure +that it returns a sensible zero logjac. + +The problem with plain old `identity` is that the default definition of +`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`: +https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154 + +This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g. +if it's `Any`), then using `identity` will error with `zero(Any)`. 
This can happen with, +for example, `ProductNamedTupleDistribution`: + +```julia +julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5))); + +julia> eltype(rand(d)) +Any +``` + +The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to +support distributions with non-numeric samples. + +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `_typed_identity` is what will allow us to obtain that custom +behaviour. +""" +function _typed_identity end +@inline _typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(_typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) From 1156a490e98c763c7e75c53412826e546d262367 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 19:32:48 +0000 Subject: [PATCH 41/57] Tidy up loads of things --- docs/src/api.md | 8 ++++ src/DynamicPPL.jl | 2 +- src/contexts/init.jl | 112 ++++++++++++++++++++++++++++++++++--------- src/model.jl | 19 ++------ 4 files changed, 103 insertions(+), 38 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 99ffc4a4c..e81f18dc7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -170,6 +170,12 @@ DynamicPPL.prefix ## Utilities +`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors. + +```@docs +typed_identity +``` + It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @@ -517,10 +523,12 @@ InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. 
+In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy. ```@docs AbstractInitStrategy init +get_param_eltype ``` ### Choosing a suitable VarInfo diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 0baa7e828..36f6791a9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -84,8 +84,8 @@ export AbstractVarInfo, # Compiler @model, # Utilities - init, OrderedDict, + typed_identity, # Model Model, getmissings, diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 870802ace..2fe71efbf 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and very rarely, [`DynamicPPL.get_param_eltype`](@ref). """ abstract type AbstractInitStrategy end @@ -14,14 +14,58 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -This function must return a tuple of: +This function must return a tuple `(x, trf)`, where -- the generated value -- a function that transforms the generated value back to the unlinked space. If the value is - already in unlinked space, then this should be `identity`. +- `x` is the generated value + +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. 
You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams` which provides known +parameters for evaluating the model, methods are implemented in order to return more specific +types. + +For the most part, a return value of `Any` will actually suffice. However, there are a few +edge cases in DynamicPPL where the element type is needed. These largely relate to +determining the element type of accumulators ahead of time (_before_ evaluation), as well as +promoting type parameters in model arguments. The classic case is when evaluating a model +with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in +SparseConnectivityTracer.jl, also require similar treatment. + +If `AbstractInitStrategy` is never used in combination with tracer types, then it is +perfectly safe to return `Any`. This does not lead to type instability downstream because +the actual accumulators will still be created with concrete Float types (the `Any` is just +used to determine whether the float type needs to be modified). + +(Detail: in fact, the above is not always true. Firstly, the accumulator argument is only +true when evaluating with ThreadSafeVarInfo. See the comments in `DynamicPPL.unflatten` for +more details. For non-threadsafe evaluation, Julia is capable of automatically promoting the +types on its own. Secondly, the promotion only matters if you are trying to directly assign +into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar tracer type, for example using +`xs[i] = MyDual`. 
This doesn't actually apply to tilde-statements like `xs[i] ~ ...` because +those use `Accessors.@set` under the hood, which also does the promotion for you.) +""" +get_param_eltype(::AbstractInitStrategy) = Any +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end + """ InitFromPrior() @@ -74,21 +118,46 @@ end """ InitFromParams( - params::Union{AbstractDict{<:VarName},NamedTuple}, + params::Any fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) -Obtain new values by extracting them from the given dictionary or NamedTuple. +Obtain new values by extracting them from the given set of `params`. + +The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which +provides a mapping from variable names to values. However, we leave the type of `params` +open in order to allow for custom parameter storage types. + +## Custom parameter storage types -The parameter `fallback` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `fallback` -can either be an initialisation strategy itself, in which case it will be -used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `InitFromPrior()`. +For `InitFromParams` to work correctly with a custom `params::P`, you need to implement -!!! note - The values in `params` must be provided in the space of the untransformed - distribution. +```julia +DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P} +``` + +This tells you how to obtain values for the random variable `vn` from `p.params`. Note that +the last argument is `InitFromParams(params)`, not just `params` itself. 
Please see the +docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour. + +If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case, +then you will not need to implement anything else. So far, this is the same as you would do +for creating any new `AbstractInitStrategy` subtype. + +However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to +implement + +```julia +DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P} +``` + +See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this +is needed. + +The argument `fallback` specifies how new values are to be obtained if they cannot be found +in `params`, or they are specified as `missing`. `fallback` can either be an initialisation +strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, +in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`. """ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P @@ -102,11 +171,8 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS end function init( - rng::Random.AbstractRNG, - vn::VarName, - dist::Distribution, - p::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}}, -) + rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because diff --git a/src/model.jl b/src/model.jl index 0219e6046..2bcfe8f98 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1021,28 +1021,19 @@ By default, this uses `eltype(varinfo)` which is slightly cursed. 
This relies on that typically, before evaluation, the parameters will have been inserted into the VarInfo's metadata field. -For InitContext, it's quite different: because InitContext is responsible for supplying the -parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside it. +For `InitContext`, it's quite different: because `InitContext` is responsible for supplying +the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside +it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more +explanation. """ function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext) return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) end get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi) function get_param_eltype(::AbstractVarInfo, ctx::InitContext) - return _get_strat_param_eltype(ctx.strategy) + return get_param_eltype(ctx.strategy) end -function _get_strat_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) - return eltype(strategy.params.vect) -end -function _get_strat_param_eltype( - strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} -) - return infer_nested_eltype(typeof(strategy.params)) -end -# No need to specify a type since new ones are generated -_get_strat_param_eltype(::Union{InitFromPrior,InitFromUniform}) = Any - """ getargnames(model::Model) From a3a4795edabf6540a68ce53ceec9bfe3c322baa3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 19:36:11 +0000 Subject: [PATCH 42/57] fix typed_identity spelling --- src/contexts/init.jl | 6 +++--- src/utils.jl | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 2fe71efbf..86a588613 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -73,7 +73,7 @@ Obtain new values by sampling from the prior distribution. 
""" struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist), _typed_identity + return rand(rng, dist), typed_identity end """ @@ -113,7 +113,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x, _typed_identity + return x, typed_identity end """ @@ -186,7 +186,7 @@ function init( else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? - x, _typed_identity + x, typed_identity end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") diff --git a/src/utils.jl b/src/utils.jl index cd91ff332..75fb805dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,7 +16,7 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. const LogProbType = float(Real) """ - _typed_identity(x) + typed_identity(x) Identity function, but with an overload for `with_logabsdet_jacobian` to ensure that it returns a sensible zero logjac. @@ -42,12 +42,12 @@ support distributions with non-numeric samples. Furthermore, in principle, the type of the log-probability should be separate from the type of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do -this is discovered, then `_typed_identity` is what will allow us to obtain that custom +this is discovered, then `typed_identity` is what will allow us to obtain that custom behaviour. 
""" -function _typed_identity end -@inline _typed_identity(x) = x -@inline Bijectors.with_logabsdet_jacobian(::typeof(_typed_identity), x) = +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = (x, zero(LogProbType)) """ From cda780abed061caa6c7aabca6f172ce0ce29bdae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 19:44:05 +0000 Subject: [PATCH 43/57] fix definition order --- src/contexts/init.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 86a588613..c32024adc 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -57,14 +57,6 @@ into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar tracer type, for e those use `Accessors.@set` under the hood, which also does the promotion for you.) """ get_param_eltype(::AbstractInitStrategy) = Any -function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) - return eltype(strategy.params.vect) -end -function get_param_eltype( - strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} -) - return infer_nested_eltype(typeof(strategy.params)) -end """ InitFromPrior() @@ -193,6 +185,11 @@ function init( init(rng, vn, dist, p.fallback) end end +function get_param_eltype( + strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}} +) + return infer_nested_eltype(typeof(strategy.params)) +end """ RangeAndLinked @@ -261,6 +258,9 @@ function init( end return (@view vr.vect[range_and_linked.range]), transform end +function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) + return eltype(strategy.params.vect) +end """ InitContext( From 99a4a200657d540d7065795a98faa4031e2e80fe Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 19:54:05 +0000 Subject: [PATCH 44/57] Improve docstrings --- src/fasteval.jl | 123 ++++++++++++++++++++---------------------------- 1 file changed, 
51 insertions(+), 72 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 3f933adf6..8bedd81f4 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -1,65 +1,3 @@ -""" -fasteval.jl ------------ - -Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a -given set of parameters: - -1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. - -2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores - them inside a VarInfo's metadata. - -In general, both of these approaches work fine, but the fact that they modify the VarInfo's -metadata can often be quite wasteful. In particular, it is very common that the only outputs -we care about from model evaluation are those which are stored in accumulators, such as log -probability densities, or `ValuesAsInModel`. - -To avoid this issue, we implement here `OnlyAccsVarInfo`, which is a VarInfo that only -contains accumulators. It implements enough of the `AbstractVarInfo` interface to not error -during model evaluation. - -Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with -it, it is mandatory that parameters are provided from outside the VarInfo, namely via -`InitContext{<:InitFromParams}`. - -NamedTuple and Dict parameters ------------------------------- - -OnlyAccsVarInfo works out of the box with the existing `InitContext{<:InitFromParams{<:P}}`, -functionality, where `P` is either a NamedTuple or a Dict. - -Vector parameters ------------------ - -Vector parameters are more complicated, since it is not possible to directly implement -`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. 
-In particular, it is not clear: - - - which parts of the vector correspond to which random variables, and - - whether the variables are linked or unlinked. - -Traditionally, this problem has been solved by `unflatten`, because that function would -place values into the VarInfo's metadata alongside the information about ranges and linking. -However, we want to avoid doing this. Thus, here, we _extract this information from the -VarInfo_ a single time when constructing a `FastLDF` object. - -This creates a single struct, `ParamsWithRanges`, which contains: - - - the vector of parameters - - a mapping from VarNames to ranges in that vector, along with link status - -When evaluating the model, we create a `FastEvalVectorContext`, which reads parameter - -Note that this assumes that the ranges and link status are static throughout the lifetime of -the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable -numbers of parameters, or models which may visit random variables in different orders depending -on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a -general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` -approach also fails with such models. -""" - using DynamicPPL: AbstractVarInfo, AccumulatorTuple, @@ -146,15 +84,53 @@ Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart ## Extended help -`FastLDF` uses `FastEvalVectorContext` internally to provide extremely rapid evaluation of -the model given a vector of parameters. +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. 
With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains +accumulators. It implements enough of the `AbstractVarInfo` interface to not error during +model evaluation. + +Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with +it, it is mandatory that parameters are provided from outside the VarInfo, namely via +`InitContext{<:InitFromParams}`. -Because it is common to call `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient` within tight loops, it is beneficial for us to -pre-compute as much of the information as possible when constructing the `FastLDF` object. -In particular, we use the provided VarInfo's metadata to extract the mapping from VarNames -to ranges and link status, and store this mapping inside the `FastLDF` object. We can later -use this to construct a FastEvalVectorContext, without having to look into a metadata again. +The main problem that we face is that it is not possible to directly implement +`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. +In particular, it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. 
+That way, when we evaluate with `DefaultContext`, we can read this information out again. +However, we want to avoid doing this. Thus, here, we _extract this information from the +VarInfo_ a single time when constructing a `FastLDF` object. Inside the `FastLDF, we store: + + - the vector of parameters + - a mapping from VarNames to ranges in that vector, along with link status + +When evaluating the model, this allows us to create an `InitFromParams{VectorWithRanges}`, which +lets us very quickly read parameter values from the vector. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a +general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` +approach also fails with such models. """ struct FastLDF{ M<:Model, @@ -285,8 +261,11 @@ end # Helper functions to extract ranges and link status # ###################################################### -# TODO: Fails for SimpleVarInfo. Do I really care enough? Ehhh, honestly, debatable. - +# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The +# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges +# and link status. So there is no motivation to use SimpleVarInfo inside a +# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue +# that there is no purpose in supporting untyped VarInfo either. 
""" get_ranges_and_linked(varinfo::VarInfo) From b58f90a0b4bfafb69331642560643c29183319de Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 20:02:41 +0000 Subject: [PATCH 45/57] Remove stray comment --- src/fasteval.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 8bedd81f4..5b9b767df 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -142,7 +142,6 @@ struct FastLDF{ model::M adtype::AD _getlogdensity::F - # See FastLDFContext for explanation of these two fields. _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adprep::ADP From bf00ca1bc280248ea591302fefedecb7d90d6d4c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 20:40:37 +0000 Subject: [PATCH 46/57] export get_param_eltype (unfortunatley) --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 36f6791a9..e9b902363 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -114,6 +114,7 @@ export AbstractVarInfo, InitFromUniform, InitFromParams, init, + get_param_eltype, # Pseudo distributions NamedDist, NoDist, From 6e8c1466049169d79076dcc7d8e1e27f73fba409 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 21:32:50 +0000 Subject: [PATCH 47/57] Add more comment --- src/contexts/init.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index c32024adc..065a11f29 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -317,12 +317,21 @@ function tilde_assume!!( # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, # which is NOT the same as `inverse(link_transform)` (because there is an additional # vectorisation step). We need `init` and `tilde_assume!!` to share this information - # but it's not clear right now how to do this. 
In my opinion, the most productive - # way forward would be to clean up the behaviour of bijectors so that we can have a - # clean separation between the linking and vectorisation parts of it. That way, `x` - # can either be unlinked, unlinked vectorised, linked, or linked vectorised, and - # regardless of which it is, we should only need to apply at most one linking and - # one vectorisation transform. + # but it's not clear right now how to do this. In my opinion, there are a couple of + # potential ways forward: + # + # 1. Just remove metadata entirely so that there is never any need to construct + # a linked vectorised value again. This would require us to use VAIMAcc as the only + # way of getting values. I consider this the best option, but it might take a long + # time. + # + # 2. Clean up the behaviour of bijectors so that we can have a complete separation + # between the linking and vectorisation parts of it. That way, `x` can either be + # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of + # which it is, we should only need to apply at most one linking and one + # vectorisation transform. Doing so would allow us to remove the first call to + # `with_logabsdet_jacobian`, and instead compose and/or uncompose the + # transformations before calling `with_logabsdet_jacobian` once. y, -inv_logjac + fwd_logjac else x, -inv_logjac From 62a8746c20431cefe2514e4716d44d0364a01328 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 11 Nov 2025 21:37:43 +0000 Subject: [PATCH 48/57] Update comment --- src/onlyaccs.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl index 2e9f34177..6aff9d04d 100644 --- a/src/onlyaccs.jl +++ b/src/onlyaccs.jl @@ -2,14 +2,16 @@ OnlyAccsVarInfo This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` -interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions. 
+interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for +`InitContext`. Note that this does not implement almost every other AbstractVarInfo interface function, and -so using this outside of FastLDF will lead to errors. +so using attempting to use this with a different leaf context such as `DefaultContext` will +result in errors. Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. -That is because values for random variables are obtained by reading from a separate entity -(such as a `FastLDFContext`), rather than from the VarInfo itself. +This is also why it only works with `InitContext`: in this case, the parameters used for +evaluation are supplied by the context instead of the metadata. """ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs @@ -19,7 +21,7 @@ function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} return OnlyAccsVarInfo(AccumulatorTuple(accs)) end -# AbstractVarInfo interface +# Minimal AbstractVarInfo interface @inline DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi @inline DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs @inline DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = From 19af4cc3f4617ec0dbac52e6052945faa3456bd3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 15:32:35 +0000 Subject: [PATCH 49/57] Remove inlines, fix OAVI docstring --- src/onlyaccs.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl index 6aff9d04d..7aee7a3bc 100644 --- a/src/onlyaccs.jl +++ b/src/onlyaccs.jl @@ -6,8 +6,7 @@ interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for `InitContext`. Note that this does not implement almost every other AbstractVarInfo interface function, and -so using attempting to use this with a different leaf context such as `DefaultContext` will -result in errors. 
+so using this with a different leaf context such as `DefaultContext` will result in errors. Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. This is also why it only works with `InitContext`: in this case, the parameters used for @@ -22,10 +21,9 @@ function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} end # Minimal AbstractVarInfo interface -@inline DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi -@inline DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs -@inline DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = - OnlyAccsVarInfo(accs) -@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false -@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false -@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi From 5f158e6e126f7537139e8d737a207205eadb72b3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 15:49:21 +0000 Subject: [PATCH 50/57] Improve docstrings --- src/contexts/init.jl | 40 +++++++++++++++++++++++++--------------- src/fasteval.jl | 2 +- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 065a11f29..420c31f16 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -5,7 +5,7 @@ Abstract type representing the possible ways of initialising new values for the variables in a model (e.g., when creating a new VarInfo). Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, -and very rarely, [`DynamicPPL.get_param_eltype`](@ref). 
+and in some cases, [`DynamicPPL.get_param_eltype`](@ref) (see its docstring for details). """ abstract type AbstractInitStrategy end @@ -17,7 +17,6 @@ Generate a new value for a random variable with the given distribution. This function must return a tuple `(x, trf)`, where - `x` is the generated value - - `trf` is a function that transforms the generated value back to the unlinked space. If the value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You can also use `Base.identity`, but if you use this, you **must** be confident that @@ -35,26 +34,37 @@ The default implementation returns `Any`. However, for `InitFromParams` which pr parameters for evaluating the model, methods are implemented in order to return more specific types. -For the most part, a return value of `Any` will actually suffice. However, there are a few -edge cases in DynamicPPL where the element type is needed. These largely relate to -determining the element type of accumulators ahead of time (_before_ evaluation), as well as -promoting type parameters in model arguments. The classic case is when evaluating a model -with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` +In general, if you are implementing a custom `AbstractInitStrategy`, correct behaviour can +only be guaranteed if you implement this method as well. However, quite often, the default +return value of `Any` will actually suffice. The cases where this does *not* suffice, and +where you _do_ have to manually implement `get_param_eltype`, are explained in the extended +help (see `??DynamicPPL.get_param_eltype` in the REPL). + +# Extended help + +There are a few edge cases in DynamicPPL where the element type is needed. These largely +relate to determining the element type of accumulators ahead of time (_before_ evaluation), +as well as promoting type parameters in model arguments. 
The classic case is when evaluating +a model with ForwardDiff: the accumulators must be set to `Dual`s, and any `Vector{Float64}` arguments must be promoted to `Vector{Dual}`. Other tracer types, for example those in SparseConnectivityTracer.jl, also require similar treatment. -If `AbstractInitStrategy` is never used in combination with tracer types, then it is +If the `AbstractInitStrategy` is never used in combination with tracer types, then it is perfectly safe to return `Any`. This does not lead to type instability downstream because the actual accumulators will still be created with concrete Float types (the `Any` is just used to determine whether the float type needs to be modified). -(Detail: in fact, the above is not always true. Firstly, the accumulator argument is only -true when evaluating with ThreadSafeVarInfo. See the comments in `DynamicPPL.unflatten` for -more details. For non-threadsafe evaluation, Julia is capable of automatically promoting the -types on its own. Secondly, the promotion only matters if you are trying to directly assign -into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar tracer type, for example using -`xs[i] = MyDual`. This doesn't actually apply to tilde-statements like `xs[i] ~ ...` because -those use `Accessors.@set` under the hood, which also does the promotion for you.) +In case that wasn't enough: in fact, even the above is not always true. Firstly, the +accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments +in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable +of automatically promoting the types on its own. Secondly, the promotion only matters if you +are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar +tracer type, for example using `xs[i] = MyDual`. 
This doesn't actually apply to +tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which +also does the promotion for you. For the gory details, see the following issues: + +- https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types +- https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion """ get_param_eltype(::AbstractInitStrategy) = Any diff --git a/src/fasteval.jl b/src/fasteval.jl index 5b9b767df..f71779d7a 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -82,7 +82,7 @@ Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart - `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD type was provided. -## Extended help +# Extended help Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a given set of parameters: From 7d38b5cfb343511c91a590106b4d2292705e21e1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 15:53:02 +0000 Subject: [PATCH 51/57] Simplify InitFromParams constructor --- src/contexts/init.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 420c31f16..a79969a13 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -164,13 +164,8 @@ in which case an error will be thrown. 
The default for `fallback` is `InitFromPr struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - - function InitFromParams( - params::P, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() - ) where {P} - return new{P,typeof(fallback)}(params, fallback) - end end +InitFromParams(params) = InitFromParams(params, InitFromPrior()) function init( rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} From 7af9c2ec546fe5fb1f21c3251d923fdc93dcabbd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 15:53:09 +0000 Subject: [PATCH 52/57] Replace map(identity, x[:]) with [i for i in x[:]] --- test/fasteval.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index f1c535643..66fea093a 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -37,7 +37,7 @@ end unlinked_vi end nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) - params = map(identity, vi[:]) + params = [x for x in vi[:]] # Iterate over all variables for vn in keys(vi) # Check that `getindex_internal` returns the same thing as using the ranges @@ -147,7 +147,7 @@ end varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = map(identity, linked_varinfo[:]) + x = [p for p in linked_varinfo[:]] # Calculate reference logp + gradient of logp using ForwardDiff ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) From ff0065554015bd21db8f3d759bc6d6a5d0800196 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 16:07:37 +0000 Subject: [PATCH 53/57] Simplify implementation for InitContext/OAVI --- src/contexts/init.jl | 4 ++-- src/onlyaccs.jl | 19 ++++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..52921111d 100644 --- a/src/contexts/init.jl 
+++ b/src/contexts/init.jl
@@ -174,8 +174,8 @@ function init(
     # of the parameters in `p.params` were actually used, and either warn or
     # error if they aren't. This is actually quite non-trivial though because
     # the structure of Dicts in particular can have arbitrary nesting.
-    return if hasvalue(p.params, vn, dist)
-        x = getvalue(p.params, vn, dist)
+    return if hasvalue(p.params, vn)
+        x = getvalue(p.params, vn)
         if x === missing
             p.fallback === nothing &&
                 error("A `missing` value was provided for the variable `$(vn)`.")
diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl
index 7aee7a3bc..940f23124 100644
--- a/src/onlyaccs.jl
+++ b/src/onlyaccs.jl
@@ -24,6 +24,19 @@ end
 DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
 DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
 DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
-Base.haskey(::OnlyAccsVarInfo, ::VarName) = false
-DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false
-BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi
+
+# Ideally, we'd define this together with InitContext, but alas that file comes way before
+# this one, and sorting out the include order is a pain.
+function tilde_assume!!(
+    ctx::InitContext,
+    dist::Distribution,
+    vn::VarName,
+    vi::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}},
+)
+    # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can
+    # skip most of the logic in the generic `tilde_assume!!` for `InitContext` in init.jl.
+ val, transform = init(ctx.rng, vn, dist, ctx.strategy) + x, inv_logjac = with_logabsdet_jacobian(transform, val) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) + return x, vi +end From 565caa3e23ef5d51b052fa735fb5d0bfa05b52b2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 16:09:47 +0000 Subject: [PATCH 54/57] Add another model to allocation tests Co-authored-by: Markus Hauru --- test/fasteval.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index 66fea093a..065a28dfc 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -116,7 +116,7 @@ end y ~ Normal(params.m, params.s) return 1.0 ~ Normal(y) end - @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner())) + @testset for model in (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) vi = VarInfo(model) fldf = DynamicPPL.Experimental.FastLDF( model, DynamicPPL.getlogjoint_internal, vi From f6abcd254d05e7f4d6f4e3a34f00e5b59c562f19 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 16:10:16 +0000 Subject: [PATCH 55/57] Revert removal of dist argument (oops) --- src/contexts/init.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 52921111d..a79969a13 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -174,8 +174,8 @@ function init( # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because # the structure of Dicts in particular can have arbitrary nesting. 
- return if hasvalue(p.params, vn) - x = getvalue(p.params, vn) + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) if x === missing p.fallback === nothing && error("A `missing` value was provided for the variable `$(vn)`.") From f72728093cc6a5e634702bd294382b99a5c3b3bf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 16:12:38 +0000 Subject: [PATCH 56/57] Format --- test/fasteval.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/fasteval.jl b/test/fasteval.jl index 065a28dfc..db2333711 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -116,7 +116,8 @@ end y ~ Normal(params.m, params.s) return 1.0 ~ Normal(y) end - @testset for model in (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) + @testset for model in + (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) vi = VarInfo(model) fldf = DynamicPPL.Experimental.FastLDF( model, DynamicPPL.getlogjoint_internal, vi From 5e003a86fe866ce0fead27f0cc044e9f803ea02d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Nov 2025 16:26:51 +0000 Subject: [PATCH 57/57] Update some outdated bits of FastLDF docstring --- src/fasteval.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index f71779d7a..082763b85 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -104,7 +104,7 @@ model evaluation. Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with it, it is mandatory that parameters are provided from outside the VarInfo, namely via -`InitContext{<:InitFromParams}`. +`InitContext`. The main problem that we face is that it is not possible to directly implement `DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`. 
@@ -116,14 +116,18 @@ In particular, it is not clear: Traditionally, this problem has been solved by `unflatten`, because that function would place values into the VarInfo's metadata alongside the information about ranges and linking. That way, when we evaluate with `DefaultContext`, we can read this information out again. -However, we want to avoid doing this. Thus, here, we _extract this information from the -VarInfo_ a single time when constructing a `FastLDF` object. Inside the `FastLDF, we store: - - - the vector of parameters - - a mapping from VarNames to ranges in that vector, along with link status - -When evaluating the model, this allows us to create an `InitFromParams{VectorWithRanges}`, which -lets us very quickly read parameter values from the vector. +However, we want to avoid using a metadata. Thus, here, we _extract this information from +the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we +store a mapping from VarNames to ranges in that vector, along with link status. + +For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all +other VarNames, this is stored in a Dict. The internal data structure used to represent this +could almost certainly be optimised further. See e.g. the discussion in +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. + +When evaluating the model, this allows us to combine the parameter vector together with those +ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read +parameter values from the vector. Note that this assumes that the ranges and link status are static throughout the lifetime of the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable