From 425addb2ef67404b5992b6416bc305d24b1976f6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 17:09:01 +0000 Subject: [PATCH 1/5] Make InitContext work with OnlyAccsVarInfo --- src/fasteval.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..dcdbaa608 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -81,6 +81,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +109,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,14 +121,11 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} + eltype = DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + @info "Inferring parameter eltype as $eltype from InitContext" + return eltype else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) From 9578bce6e2767eb1a62c019220e750666e56bd76 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 17:26:12 +0000 Subject: [PATCH 2/5] Do not convert NamedTuple to Dict --- src/contexts/init.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..efc6f1087 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) From f9d61d6ca96f7dfecd6ee3e78129346ceda51026 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 17:34:18 +0000 Subject: [PATCH 3/5] remove logging --- src/fasteval.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index dcdbaa608..3b1ae2550 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -122,9 +122,7 @@ function DynamicPPL.get_param_eltype( if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} - eltype = DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) - @info "Inferring parameter eltype as $eltype from InitContext" - return eltype + return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) else error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", From cebfcbe379502dcada752989fca1ae89e8b86d83 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 17:24:35 +0000 Subject: [PATCH 4/5] Enable InitFromPrior and InitFromUniform too --- src/fasteval.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/fasteval.jl b/src/fasteval.jl index 3b1ae2550..2e1bccdc6 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,10 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -123,6 +127,9 @@ function DynamicPPL.get_param_eltype( return eltype(leaf_ctx.params) elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} + # No need to enforce any particular eltype here, since new parameters are sampled + return Any else error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", From c35dff5c54c0ac91b185110276ef8788cb17d94b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 18:27:36 +0000 Subject: [PATCH 5/5] Fix `infer_nested_eltype` invocation --- src/fasteval.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 2e1bccdc6..fbc6a61ce 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -126,7 +126,7 @@ function DynamicPPL.get_param_eltype( if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} - return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} # No need to enforce any particular eltype here, since new parameters are sampled return Any