Skip to content

Commit ecca1af

Browse files
committed
Fast InitContext (#1125)
* Make InitContext work with OnlyAccsVarInfo * Do not convert NamedTuple to Dict * remove logging * Enable InitFromPrior and InitFromUniform too * Fix `infer_nested_eltype` invocation
1 parent b060f2f commit ecca1af

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

src/contexts/init.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
102102
function InitFromParams(
103103
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
104104
)
105-
return InitFromParams(to_varname_dict(params), fallback)
105+
return new{typeof(params),typeof(fallback)}(params, fallback)
106106
end
107107
end
108108
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)

src/fasteval.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ using DynamicPPL:
6060
AbstractContext,
6161
AbstractVarInfo,
6262
AccumulatorTuple,
63+
InitContext,
64+
InitFromParams,
65+
InitFromPrior,
66+
InitFromUniform,
6367
LogJacobianAccumulator,
6468
LogLikelihoodAccumulator,
6569
LogPriorAccumulator,
@@ -81,6 +85,7 @@ using DynamicPPL:
8185
getlogprior_internal,
8286
leafcontext
8387
using ADTypes: ADTypes
88+
using BangBang: BangBang
8489
using Bijectors: with_logabsdet_jacobian
8590
using AbstractPPL: AbstractPPL, VarName
8691
using Distributions: Distribution
@@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
108113
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
109114
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
110115
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
116+
@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false
117+
@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false
118+
@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi
111119
function DynamicPPL.get_param_eltype(
112120
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
113121
)
@@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype(
117125
leaf_ctx = DynamicPPL.leafcontext(model.context)
118126
if leaf_ctx isa FastEvalVectorContext
119127
return eltype(leaf_ctx.params)
128+
elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams}
129+
return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params))
130+
elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}}
131+
# No need to enforce any particular eltype here, since new parameters are sampled
132+
return Any
120133
else
121-
# TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}.
122-
# See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
123-
# figure out the parameter type from a NamedTuple or Dict. The benefit of
124-
# implementing this for InitContext is that we could then use OnlyAccsVarInfo with
125-
# it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
126-
# that Mooncake / Enzyme should be able to differentiate through that too and
127-
# provide a NamedTuple of gradients (although I haven't tested this yet).
128134
error(
129135
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
130136
)

0 commit comments

Comments
 (0)