|
1 | 1 |
|
| 2 | +using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero |
| 3 | +base(dx::Tangent) = backing(canonicalize(dx)) |
| 4 | +base(dx) = dx |
| 5 | +const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor} |
| 6 | + |
2 | 7 | struct Leaf{R,S} |
3 | 8 | rule::R |
4 | 9 | state::S |
|
18 | 23 |
|
19 | 24 | subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄) |
20 | 25 |
|
| 26 | +update!(::Nothing, x, ::Zero...) = nothing, x |
21 | 27 | update!(::Nothing, x, x̄s...) = nothing, x |
22 | 28 |
|
| 29 | +update!(ℓ::Leaf, x, ::Zero...) = ℓ, x |
23 | 30 | function update!(ℓ::Leaf, x, x̄s...) |
24 | | - if all(isnothing, x̄s) |
25 | | - return ℓ, x |
26 | | - else |
27 | | - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, x̄s...) |
28 | | - return Leaf(ℓ.rule, s′), subtract!(x, x̄′) |
29 | | - end |
| 31 | + s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...) |
| 32 | + Leaf(ℓ.rule, s′), subtract!(x, x̄′) |
30 | 33 | end |
31 | 34 |
|
| 35 | +update!(tree, x, ::Zero...) = tree, x |
32 | 36 | function update!(tree, x, x̄s...) |
33 | | - if all(isnothing, x̄s) |
34 | | - return tree, x |
35 | | - else |
36 | | - x̄s′ = map(x̄ -> functor(typeof(x), x̄)[1], x̄s) |
37 | | - x′, re = functor(typeof(x), x) |
38 | | - xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) |
39 | | - return map(first, xtree), re(map(last, xtree)) |
40 | | - end |
| 37 | + x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) |
| 38 | + x′, re = functor(typeof(x), x) |
| 39 | + xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) |
| 40 | + map(first, xtree), re(map(last, xtree)) |
41 | 41 | end |
42 | 42 |
|
43 | 43 | function update(tree, x, x̄s...) |
|
0 commit comments