Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
function InitFromParams(
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
)
return InitFromParams(to_varname_dict(params), fallback)
return new{typeof(params),typeof(fallback)}(params, fallback)
Copy link
Member Author

@penelopeysm penelopeysm Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code, which is present on main, is actually (sadly) quite bad code (written by yours truly). For the trivial model, it had the effect of making simple NamedTuples 10x slower because to_varname_dict returns a Dict{VarName,Any}, which is too loosely typed and leads to slow lookups etc.

DynamicPPL.jl/src/utils.jl

Lines 851 to 855 in 08fffa2

# Convert (x=1,) to Dict(@varname(x) => 1)
function to_varname_dict(nt::NamedTuple)
return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt))
end
to_varname_dict(d::AbstractDict) = d

end
end
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
Expand Down
20 changes: 13 additions & 7 deletions src/fasteval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ using DynamicPPL:
AbstractContext,
AbstractVarInfo,
AccumulatorTuple,
InitContext,
InitFromParams,
InitFromPrior,
InitFromUniform,
LogJacobianAccumulator,
LogLikelihoodAccumulator,
LogPriorAccumulator,
Expand All @@ -81,6 +85,7 @@ using DynamicPPL:
getlogprior_internal,
leafcontext
using ADTypes: ADTypes
using BangBang: BangBang
using Bijectors: with_logabsdet_jacobian
using AbstractPPL: AbstractPPL, VarName
using Distributions: Distribution
Expand Down Expand Up @@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false
@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false
@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi
function DynamicPPL.get_param_eltype(
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
)
Expand All @@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype(
leaf_ctx = DynamicPPL.leafcontext(model.context)
if leaf_ctx isa FastEvalVectorContext
return eltype(leaf_ctx.params)
elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams}
return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params))
elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}}
# No need to enforce any particular eltype here, since new parameters are sampled
return Any
else
# TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}.
# See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
# figure out the parameter type from a NamedTuple or Dict. The benefit of
# implementing this for InitContext is that we could then use OnlyAccsVarInfo with
# it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
# that Mooncake / Enzyme should be able to differentiate through that too and
# provide a NamedTuple of gradients (although I haven't tested this yet).
error(
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
)
Expand Down
Loading