@@ -10,10 +10,10 @@ abstract type AbstractRule end
1010# ## setup
1111# ##
1212
13- mutable struct Leaf{R,S}
13+ mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
1414 rule:: R
1515 state:: S
16- frozen:: Bool
16+ frozen:: Bool # mutability also allows this flag to be changed
1717end
1818
1919@functor Leaf
4545function update! (tree, model, grad)
4646 # First walk is to accumulate the gradient. This recursion visits every copy of
4747 # shared leaves, but stops when branches are absent from the gradient:
48- gdict = IdDict {Leaf, Any} ()
49- grads! (gdict, tree, model, grad)
50- # Second walk is to update the model, using same fmap walk as setup:
51- xdict = IdDict {Leaf, Any} () # (this exists to allow for shared ℓ without shared x)
52- newmodel = fmap (model, tree; exclude = isnumeric) do x, ℓ
53- ℓ isa Leaf || error (" this state does not match the model, expected a Leaf here" )
48+ dict = IdDict {Leaf, Any} ()
49+ grads! (dict, tree, model, grad)
50+ # Second walk is to update the model. fmap walks the state tree as its first argument, so sharing is tracked by Leaf identity.
51+ newmodel = fmap (tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x
5452 ℓ. frozen && return x
55- haskey (gdict, ℓ) || return x # no gradient seen, nothing to do
56- if haskey (xdict, ℓ)
57- # This means that shared ℓ encodes sharing not noted in x. Won't happen with setup above, no API yet.
58- x′ = xdict[ℓ] # ... and is why xdict exists.
59- size (x′) == size (x) || error (" the same Leaf belongs to arrays of size $(size (x)) and $(size (x′)) " )
60- return x′
61- end
62- s′, x̄′ = apply! (ℓ. rule, ℓ. state, x, gdict[ℓ])
53+ haskey (dict, ℓ) || return x # no gradient seen, nothing to do
54+ s′, x̄′ = apply! (ℓ. rule, ℓ. state, x, dict[ℓ])
6355 ℓ. state = s′ # to get state out of here, rely on mutability of Leaf
64- xdict[ℓ] = subtract! (x, x̄′)
56+ subtract! (x, x̄′)
6557 end
6658 tree, newmodel # note that tree is guaranteed to be updated
6759end
@@ -89,6 +81,13 @@ function update(tree, x, x̄s...)
8981 update! (t′, x′, x̄s... )
9082end
9183
84+ # This differs from _default_walk(f, x, y) in taking the reconstructor `re` from the 2nd argument `y`; the cache will still operate on the 1st argument `x`.
85+ function _second_walk (f, x, y)
86+ x′, _ = functor (typeof (y), x)
87+ y′, re = functor (y)
88+ re (map (f, x′, y′))
89+ end
90+
9291# default all rules to first order calls
9392apply! (o, state, x, dx, dxs... ) = apply! (o, state, x, dx)
9493
0 commit comments