Skip to content

Commit a494d00

Browse files
committed
Make InitContext work with OnlyAccsVarInfo
1 parent 4ec0c72 commit a494d00

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/fasteval.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ using DynamicPPL:
8181
getlogprior_internal,
8282
leafcontext
8383
using ADTypes: ADTypes
84+
using BangBang: BangBang
8485
using Bijectors: with_logabsdet_jacobian
8586
using AbstractPPL: AbstractPPL, VarName
8687
using Distributions: Distribution
@@ -108,6 +109,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators())
108109
DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi
109110
DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs
110111
DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs)
112+
@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false
113+
@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false
114+
@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi
111115
function DynamicPPL.get_param_eltype(
112116
::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model
113117
)
@@ -117,14 +121,11 @@ function DynamicPPL.get_param_eltype(
117121
leaf_ctx = DynamicPPL.leafcontext(model.context)
118122
if leaf_ctx isa FastEvalVectorContext
119123
return eltype(leaf_ctx.params)
124+
elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams}
125+
eltype = DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params)
126+
@info "Inferring parameter eltype as $eltype from InitContext"
127+
return eltype
120128
else
121-
# TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}.
122-
# See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to
123-
# figure out the parameter type from a NamedTuple or Dict. The benefit of
124-
# implementing this for InitContext is that we could then use OnlyAccsVarInfo with
125-
# it, which means fast evaluation with NamedTuple or Dict parameters! And I believe
126-
# that Mooncake / Enzyme should be able to differentiate through that too and
127-
# provide a NamedTuple of gradients (although I haven't tested this yet).
128129
error(
129130
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
130131
)

0 commit comments

Comments
 (0)