|
20 | 20 | Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b) |
21 | 21 |
|
22 | 22 | function setup(rule::AbstractRule, model) |
23 | | - cnt = Ref(0) |
24 | | - # Rely on Functors to identify shared arrays, they will share a Leaf in this tree: |
25 | | - tree = fmapstructure(model, exclude = isnumeric) do x |
26 | | - cnt[] += 1 |
27 | | - Leaf(rule, init(rule, x)) |
28 | | - end |
29 | | - cnt[] == 0 && @warn "setup found no parameters in the given model" |
| 23 | + cache = IdDict() |
| 24 | + tree = _setup(rule, model; cache) |
| 25 | + isempty(cache) && @warn "setup found no trainable parameters in this model" |
30 | 26 | tree |
31 | 27 | end |
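For context, a hedged usage sketch (the NamedTuple model is invented; setup and Descent are Optimisers.jl's exported API): an array reached twice in the model gets a single shared Leaf in the returned tree, which is exactly what the IdDict cache guarantees.

using Optimisers

shared = [1.0, 2.0]                        # one array, reachable along two paths
model = (layers = (shared, shared), bias = [0.0])
tree = Optimisers.setup(Descent(0.1), model)
tree.layers[1] === tree.layers[2]   # true: the cache returns the same Leaf twice
tree.layers[1] === tree.bias        # false: distinct arrays get distinct Leafs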
32 | 28 |
|
| 29 | +# _setup is almost fmapstructure, but needs a _trainable walk, and a cache which skips numbers and other isbits values (they have no identity to track).
| 30 | +function _setup(rule, x; cache) |
| 31 | + haskey(cache, x) && return cache[x] |
| 32 | + if isnumeric(x) |
| 33 | + ℓ = Leaf(rule, init(rule, x)) |
| 34 | + if isbits(x) |
| 35 | + cache[nothing] = nothing # sentinel entry: a parameter was found, so setup's warning won't fire
| 36 | + ℓ |
| 37 | + else |
| 38 | + cache[x] = ℓ |
| 39 | + end |
| 40 | + else |
| 41 | + map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x)) |
| 42 | + end |
| 43 | +end |
| 44 | + |
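A hedged sketch of what the isbits branch refuses to do (assuming StaticArrays supplies an isbits parameter): identity is meaningless for isbits values, so they are never entered into the cache and never tied; the cache[nothing] sentinel only marks that a parameter was seen, keeping setup's warning quiet.

using Optimisers, StaticArrays

v = SA[1.0, 2.0]                    # an SVector is isbits: no identity to share
m = (a = v, b = v)
t = Optimisers.setup(Descent(0.1), m)
t.a === t.b    # false: two separate Leafs, despite v appearing twice
t.a == t.b     # true: equal children, per the == method above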
33 | 45 | function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! |
34 | 46 | ioc = IOContext(io, :compact => true) |
35 | 47 | print(ioc, "Leaf(", ℓ.rule, ", ") |
|
41 | 53 | ### update |
42 | 54 | ### |
43 | 55 |
|
44 | | -function update!(tree, model, grad) |
| 56 | +function update(tree, model, grad, higher...) |
| 57 | + t′ = fmap(copy, tree; exclude = maywrite) # walks inside Leaf |
| 58 | + x′ = fmap(copy, model; exclude = maywrite) |
| 59 | + update!(t′, x′, grad, higher...) |
| 60 | +end |
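A hedged sketch of the contract this gives update (toy model, public API): everything writable is copied first, so the caller's tree and model come back untouched.

m = (w = [1.0, 2.0],)
t = Optimisers.setup(Descent(0.5), m)
t2, m2 = Optimisers.update(t, m, (w = [1.0, 1.0],))
m.w    # [1.0, 2.0] -- original untouched, update copied before mutating
m2.w   # [0.5, 1.5] -- x .- 0.5 .* grad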
| 61 | + |
| 62 | +function update!(tree, model, grad, higher...) |
45 | 63 | # First walk is to accumulate the gradient. This recursion visits every copy of |
46 | 64 | # shared leaves, but stops when branches are absent from the gradient: |
47 | | - dict = IdDict{Leaf, Any}() |
48 | | - grads!(dict, tree, model, grad) |
49 | | - # Second walk is to update the model. The walk taken follows Leaf identity |
50 | | - newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x |
51 | | - haskey(dict, ℓ) || return x # no gradient seen, nothing to do |
52 | | - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) |
53 | | - ℓ.state = s′ # to get state out of here, rely on mutability of Leaf |
| 65 | + grads = IdDict{Leaf, Any}() |
| 66 | + _grads!(grads, tree, model, grad, higher...) |
| 67 | + # Second walk is to update the model. The params cache is indexed by (tree,x),
| 68 | + # so that identical (===) Leafs can tie isbits parameters, but setup won't create such ties for you:
| 69 | + newmodel = _update!(tree, model; grads, params = IdDict()) |
| 70 | + tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necessarily a tree.
| 71 | +end |
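For contrast, a hedged sketch of the mutating path (same toy model): update! writes into writable arrays in place, so old handles into the model observe the step, and the returned tree replaces the one passed in.

m = (w = [1.0, 2.0],)
t = Optimisers.setup(Descent(0.5), m)
t2, m2 = Optimisers.update!(t, m, (w = [1.0, 1.0],))
m.w === m2.w   # true: subtract! wrote into the original array
m.w            # [0.5, 1.5]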
| 72 | + |
| 73 | +function _update!(tree, x; grads, params) |
| 74 | + haskey(params, (tree,x)) && return params[(tree,x)] |
| 75 | + isbits(tree) && return x # trivial isbits subtrees like () and (((),),) are returned unchanged and never cached
| 76 | + x′, re = functor(x) |
| 77 | + x′′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′) |
| 78 | + params[(tree,x)] = re(x′′) |
| 79 | +end |
| 80 | +function _update!(ℓ::Leaf, x; grads, params) |
| 81 | + haskey(params, (ℓ,x)) && return params[(ℓ,x)] |
| 82 | + params[(ℓ,x)] = if haskey(grads, ℓ) |
| 83 | + ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...) |
54 | 84 | subtract!(x, x̄′) |
| 85 | + else |
| 86 | + x # no gradient seen |
55 | 87 | end |
56 | | - tree, newmodel # note that tree is guaranteed to be updated |
57 | 88 | end |
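The (tree, x) keying is what ties isbits parameters under identical Leafs: if a hand-built tree (not one from setup) places one Leaf at two isbits positions, both positions resolve to one cache entry. A hedged sketch, assuming the two-field Leaf constructor from this diff and StaticArrays:

using Optimisers, StaticArrays

ℓ = Optimisers.Leaf(Descent(0.1), nothing)   # Descent's init state is nothing
t = (a = ℓ, b = ℓ)                           # hand-tied: one Leaf, two positions
m = (a = SA[1.0], b = SA[1.0])
_, m2 = Optimisers.update!(t, m, (a = SA[1.0], b = SA[1.0]))
m2.a           # SA[0.8]: both gradients were summed before apply! ran once
m2.a === m2.b  # true: one (tree, x) key served both positions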
58 | 89 |
|
59 | 90 | subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) |
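A hedged illustration of the two branches (maywrite is defined elsewhere in the package; StaticArrays supplies the immutable case):

x = [1.0, 2.0]
Optimisers.subtract!(x, [0.1, 0.1])    # writable: mutates, x is now [0.9, 1.9]

y = SA[1.0, 2.0]
Optimisers.subtract!(y, SA[0.1, 0.1])  # immutable: returns a new SVector, SA[0.9, 1.9]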
60 | 91 |
|
61 | | -grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing |
62 | | -function grads!(dict::IdDict, ℓ::Leaf, x, x̄) |
63 | | - x̄₀ = get(dict, ℓ, ZeroTangent()) |
64 | | - dict[ℓ] = x̄ + x̄₀ # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible. |
| 92 | +_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing |
| 93 | +function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...) |
| 94 | + x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s)) |
| 95 | + dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation via broadcasted(+, x̄s, x̄s₀) is also possible.
65 | 96 | nothing |
66 | 97 | end |
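A hedged sketch of the accumulation (tied arrays, public API): a gradient reaching one Leaf along several paths is summed here before apply! runs, so tied weights take one combined step.

w = [1.0, 2.0]
m = (a = w, b = w)                    # setup ties these to one shared Leaf
t = Optimisers.setup(Descent(0.1), m)
_, m2 = Optimisers.update!(t, m, (a = [1.0, 1.0], b = [1.0, 1.0]))
m2.a   # [0.8, 1.8]: the effective gradient is [2.0, 2.0], the sum over both paths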
67 | | -grads!(dict::IdDict, t, x, ::Zero) = nothing |
68 | | -function grads!(dict::IdDict, tree, x, x̄s...) |
69 | | - # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from |
| 98 | +_grads!(dict::IdDict, t, x, ::Zero...) = nothing |
| 99 | +function _grads!(dict::IdDict, tree, x, x̄s...) |
| 100 | + # The only reason _grads! takes the model is that functor(typeof(x), base(x̄)) may differ from
70 | 101 | # functor(typeof(tree), base(x̄)), for things like Transpose |
71 | 102 | x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) |
72 | 103 | x′, _ = functor(typeof(x), x) |
73 | | - foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) |
74 | | -end |
75 | | - |
76 | | -function update(tree, x, x̄s...) |
77 | | - t′ = fmap(copy, tree; exclude = maywrite) # goes inside Leaf |
78 | | - x′ = fmap(copy, x; exclude = maywrite) |
79 | | - update!(t′, x′, x̄s...) |
80 | | -end |
81 | | - |
82 | | -# This differs from _default_walk(f,x,y) in taking re from 2nd argument, but cache will still operate on the first |
83 | | -function _second_walk(f, x, y) |
84 | | - x′, _ = functor(typeof(y), x) |
85 | | - y′, re = functor(y) |
86 | | - re(map(f, x′, y′)) |
87 | | -end |
88 | | - |
89 | | -# When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state. |
90 | | -# This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work. |
91 | | -struct LeafCache <: AbstractDict{Leaf,Any} |
92 | | - dict::IdDict{Leaf,Any} |
| 104 | + foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) |
93 | 105 | end |
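A hedged sketch of the Transpose case this comment is about (relying on Functors.jl's handling of LinearAlgebra wrappers; details may differ): setup places the Leaf at the parent array, and re-functoring the gradient with typeof(x) lets a plain matrix gradient, shaped like the transposed view, reach that Leaf correctly.

A = ones(2, 3)
m = (w = transpose(A),)               # the Leaf ends up at t.w.parent, i.e. at A
t = Optimisers.setup(Descent(0.1), m)
ḡ = (w = fill(2.0, 3, 2),)            # "natural" gradient, shaped like w, not like A
Optimisers.update!(t, m, ḡ)
A      # all entries now 0.8: the gradient was re-wrapped into parent space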
94 | | -LeafCache() = LeafCache(IdDict{Leaf,Any}()) |
95 | | - |
96 | | -Base.setindex!(c::LeafCache, x, ℓ::Leaf) = setindex!(c.dict, x, ℓ) |
97 | | -Base.setindex!(c::LeafCache, x, _) = nothing |
98 | | -Base.in(k, c::LeafCache) = k in c.dict |
99 | | -Base.haskey(c::LeafCache, k) = haskey(c.dict, k) |
100 | | -Base.getindex(c::LeafCache, ℓ::Leaf) = getindex(c.dict, ℓ) |
101 | | -Base.iterate(c::LeafCache, i = 0) = iterate(c.dict, i) |
102 | | -Base.length(c::LeafCache) = length(c.dict) |
103 | 106 |
|
104 | 107 | # default all rules to first order calls |
105 | 108 | apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) |
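A hedged sketch of this fallback (g2 is a hypothetical second-order gradient): a rule defining only the first-order method, like Descent, accepts the extra term through update! and silently drops it.

m = (w = [1.0],)
t = Optimisers.setup(Descent(0.1), m)
g  = (w = [2.0],)
g2 = (w = [99.0],)                    # hypothetical second-order gradient
_, m2 = Optimisers.update!(t, m, g, g2)
m2.w   # [0.8]: only g was used; the fallback above discarded g2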
|