diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..efc6f1087 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..fbc6a61ce 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,10 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -81,6 +85,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} + return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) + elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} + # No need to enforce any particular eltype here, since new parameters are sampled + return Any else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", )