@@ -81,6 +81,7 @@ using DynamicPPL:
8181 getlogprior_internal,
8282 leafcontext
8383using ADTypes: ADTypes
84+ using BangBang: BangBang
8485using Bijectors: with_logabsdet_jacobian
8586using AbstractPPL: AbstractPPL, VarName
8687using Distributions: Distribution
@@ -108,6 +109,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
108109DynamicPPL. maybe_invlink_before_eval!! (vi:: OnlyAccsVarInfo , :: Model ) = vi
109110DynamicPPL. getaccs (vi:: OnlyAccsVarInfo ) = vi. accs
110111DynamicPPL. setaccs!! (:: OnlyAccsVarInfo , accs:: AccumulatorTuple ) = OnlyAccsVarInfo (accs)
112+ @inline Base. haskey (:: OnlyAccsVarInfo , :: VarName ) = false
113+ @inline DynamicPPL. is_transformed (:: OnlyAccsVarInfo ) = false
114+ @inline BangBang. push!! (vi:: OnlyAccsVarInfo , vn, y, dist) = vi
111115function DynamicPPL. get_param_eltype (
112116 :: Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}} , model:: Model
113117)
@@ -117,14 +121,11 @@ function DynamicPPL.get_param_eltype(
117121 leaf_ctx = DynamicPPL. leafcontext (model. context)
118122 if leaf_ctx isa FastEvalVectorContext
119123 return eltype (leaf_ctx. params)
124+ elseif leaf_ctx isa InitContext{<: Any ,<: InitFromParams }
125+ eltype = DynamicPPL. infer_nested_eltype (leaf_ctx. strategy. params)
126+ @info " Inferring parameter eltype as $eltype from InitContext"
127+ return eltype
120128 else
121- # TODO (penelopeysm): In principle this can be done with InitContext{InitWithParams}.
122- # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
123- # figure out the parameter type from a NamedTuple or Dict. The benefit of
124- # implementing this for InitContext is that we could then use OnlyAccsVarInfo with
125- # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
126- # that Mooncake / Enzyme should be able to differentiate through that too and
127- # provide a NamedTuple of gradients (although I haven't tested this yet).
128129 error (
129130 " OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof (leaf_ctx)) " ,
130131 )
0 commit comments