@@ -60,6 +60,10 @@ using DynamicPPL:
6060 AbstractContext,
6161 AbstractVarInfo,
6262 AccumulatorTuple,
63+ InitContext,
64+ InitFromParams,
65+ InitFromPrior,
66+ InitFromUniform,
6367 LogJacobianAccumulator,
6468 LogLikelihoodAccumulator,
6569 LogPriorAccumulator,
@@ -81,6 +85,7 @@ using DynamicPPL:
8185 getlogprior_internal,
8286 leafcontext
8387using ADTypes: ADTypes
88+ using BangBang: BangBang
8489using Bijectors: with_logabsdet_jacobian
8590using AbstractPPL: AbstractPPL, VarName
8691using Distributions: Distribution
@@ -108,6 +113,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
108113DynamicPPL. maybe_invlink_before_eval!! (vi:: OnlyAccsVarInfo , :: Model ) = vi
109114DynamicPPL. getaccs (vi:: OnlyAccsVarInfo ) = vi. accs
110115DynamicPPL. setaccs!! (:: OnlyAccsVarInfo , accs:: AccumulatorTuple ) = OnlyAccsVarInfo (accs)
116+ @inline Base. haskey (:: OnlyAccsVarInfo , :: VarName ) = false
117+ @inline DynamicPPL. is_transformed (:: OnlyAccsVarInfo ) = false
118+ @inline BangBang. push!! (vi:: OnlyAccsVarInfo , vn, y, dist) = vi
111119function DynamicPPL. get_param_eltype (
112120 :: Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}} , model:: Model
113121)
@@ -117,14 +125,12 @@ function DynamicPPL.get_param_eltype(
117125 leaf_ctx = DynamicPPL. leafcontext (model. context)
118126 if leaf_ctx isa FastEvalVectorContext
119127 return eltype (leaf_ctx. params)
128+ elseif leaf_ctx isa InitContext{<: Any ,<: InitFromParams }
129+ return DynamicPPL. infer_nested_eltype (typeof (leaf_ctx. strategy. params))
130+ elseif leaf_ctx isa InitContext{<: Any ,<: Union{InitFromPrior,InitFromUniform} }
131+ # No need to enforce any particular eltype here, since new parameters are sampled
132+ return Any
120133 else
121- # TODO (penelopeysm): In principle this can be done with InitContext{InitWithParams}.
122- # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
123- # figure out the parameter type from a NamedTuple or Dict. The benefit of
124- # implementing this for InitContext is that we could then use OnlyAccsVarInfo with
125- # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
126- # that Mooncake / Enzyme should be able to differentiate through that too and
127- # provide a NamedTuple of gradients (although I haven't tested this yet).
128134 error (
129135 " OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof (leaf_ctx)) " ,
130136 )
0 commit comments