@@ -48,7 +48,7 @@ function update!(tree, model, grad)
4848 dict = IdDict {Leaf, Any} ()
4949 grads! (dict, tree, model, grad)
5050 # Second walk is to update the model. The walk taken follows Leaf identity
51- newmodel = fmap (tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x
51+ newmodel = fmap (tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache () ) do ℓ, x
5252 ℓ. frozen && return x
5353 haskey (dict, ℓ) || return x # no gradient seen, nothing to do
5454 s′, x̄′ = apply! (ℓ. rule, ℓ. state, x, dict[ℓ])
@@ -88,6 +88,21 @@ function _second_walk(f, x, y)
8888 re (map (f, x′, y′))
8989end
9090
91+ # When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state.
92+ # This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work.
93+ struct LeafCache <: AbstractDict{Leaf,Any}
94+ dict:: IdDict{Leaf,Any}
95+ end
96+ LeafCache () = LeafCache (IdDict {Leaf,Any} ())
97+
98+ Base. setindex! (c:: LeafCache , x, ℓ:: Leaf ) = setindex! (c. dict, x, ℓ)
99+ Base. setindex! (c:: LeafCache , x, _) = nothing
100+ Base. in (k, c:: LeafCache ) = k in c. dict
101+ Base. haskey (c:: LeafCache , k) = haskey (c. dict, k)
102+ Base. getindex (c:: LeafCache , ℓ:: Leaf ) = getindex (c. dict, ℓ)
103+ Base. iterate (c:: LeafCache , i = 0 ) = iterate (c. dict, i)
104+ Base. length (c:: LeafCache ) = length (c. dict)
105+
91106# default all rules to first order calls
92107apply! (o, state, x, dx, dxs... ) = apply! (o, state, x, dx)
93108
0 commit comments