Skip to content

Commit 7c778bd

Browse files
authored
Allow types from ChainRules (#41)
* using ChainRulesCore * version * indents * add some tests * no weird indents
1 parent cf7e403 commit 7c778bd

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@ authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
44
version = "0.2.0"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112

1213
[compat]
14+
ChainRulesCore = "1"
1315
Functors = "0.2.7"
1416
julia = "1.6"
1517

1618
[extras]
1719
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1920
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2021
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2122

2223
[targets]
23-
test = ["Test", "ChainRulesCore", "StaticArrays", "Zygote"]
24+
test = ["Test", "StaticArrays", "Zygote"]

src/interface.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11

2+
using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
3+
base(dx::Tangent) = backing(canonicalize(dx))
4+
base(dx) = dx
5+
const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
6+
27
struct Leaf{R,S}
38
rule::R
49
state::S
@@ -18,26 +23,21 @@ end
1823

1924
subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
2025

26+
update!(::Nothing, x, ::Zero...) = nothing, x
2127
update!(::Nothing, x, x̄s...) = nothing, x
2228

29+
update!(ℓ::Leaf, x, ::Zero...) = ℓ, x
2330
function update!(ℓ::Leaf, x, x̄s...)
24-
if all(isnothing, x̄s)
25-
return ℓ, x
26-
else
27-
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, x̄s...)
28-
return Leaf(ℓ.rule, s′), subtract!(x, x̄′)
29-
end
31+
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
32+
Leaf(ℓ.rule, s′), subtract!(x, x̄′)
3033
end
3134

35+
update!(tree, x, ::Zero...) = tree, x
3236
function update!(tree, x, x̄s...)
33-
if all(isnothing, x̄s)
34-
return tree, x
35-
else
36-
x̄s′ = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
37-
x′, re = functor(typeof(x), x)
38-
xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
39-
return map(first, xtree), re(map(last, xtree))
40-
end
37+
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
38+
x′, re = functor(typeof(x), x)
39+
xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
40+
map(first, xtree), re(map(last, xtree))
4141
end
4242

4343
function update(tree, x, x̄s...)

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
3131
s3, m3 = Optimisers.update!(s, m, g)
3232
@test objectid(m3[1]) == mid
3333
@test m3[1] [1,2] .- 0.1 .* [25, 33]
34+
35+
g4 = Tangent{typeof(m)}(g...)
36+
s4, m4 = Optimisers.update!(s, ([1.0, 2.0],), g4)
37+
@test m4[1] [1,2] .- 0.1 .* [25, 33]
3438
end
3539

3640
@testset "gradient clipping" begin
@@ -74,6 +78,11 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
7478
@test mf2.x == [1,2]
7579
@test mf2.y == (a = sin, b = [2.9, 3.9], c = 5)
7680

81+
gf3 = Tangent{typeof(mf)}(; x = NoTangent(), y = Tangent{typeof(mf.y)}(; a = NoTangent(), b = [1,1], c = 1))
82+
_, mf3 = Optimisers.update(sf, mf, gf3) # the same, but with ChainRules types
83+
@test mf3.x == [1,2]
84+
@test mf3.y == (a = sin, b = [2.9, 3.9], c = 5)
85+
7786
# TwoThirds has functor a,c only, and trainable a only
7887
mt = TwoThirds(Float32[1,2], Float32[3,4], Float32[5,6])
7988
mt10 = fmap(x -> 10x, mt)

0 commit comments

Comments
 (0)