using BenchmarkTools
using ChainRulesCore
using Functors
using Optimisers
using Test
using Zygote, Flux
# Collect every trainable numeric array in the (possibly nested) model `x`.
#
# Walks `x` with `Optimisers._TrainableStructWalk` so that only fields marked
# trainable by Optimisers.jl are visited, and gathers each numeric leaf into a
# flat vector. Returns a `Vector{AbstractArray}` (heterogeneous eltypes are
# expected, e.g. weight matrices and bias vectors of different shapes).
function trainables1(x)
    params = AbstractArray[]
    exclude(node) = Optimisers.isnumeric(node)
    fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do leaf
        push!(params, leaf)
        return leaf
    end
    return params
end
1616
# Pullback helper for `trainables1`: scatter the flat cotangent vector `Δ`
# back into a structural tangent shaped like `x`.
#
# Re-walks `x` with the same exclude/walk pair as `trainables1`, so the k-th
# leaf visited here corresponds to the k-th array collected by `trainables1`;
# a counter pairs each leaf with its cotangent `Δ[k]`. `fmapstructure`
# returns plain nested (named)tuples rather than reconstructing model types.
function ∇trainables1(x, Δ)
    exclude(node) = Optimisers.isnumeric(node)
    k = 0
    return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _
        k += 1
        return Δ[k]
    end
end
24+
25+
# Custom reverse rule so Zygote can differentiate through `trainables1`,
# whose `push!`-based implementation is otherwise non-differentiable
# (mutation). The pullback maps the flat cotangent back onto the structure
# of `x` via `∇trainables1`; `NoTangent()` is for the function itself.
function ChainRulesCore.rrule(::typeof(trainables1), x)
    ps = trainables1(x)
    function trainables_back(Δ)
        return (NoTangent(), ∇trainables1(x, unthunk(Δ)))
    end
    return ps, trainables_back
end
1732# ###########
1833
1934using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
4964
5065
# Toy loss over a collection of parameter arrays: the total sum of squares.
# Materializes the per-array sums (rather than using a lazy generator) so the
# expression stays friendly to Zygote differentiation in the benchmarks below.
function floss(ps)
    return sum(map(p -> sum(abs2, p), ps))
end
5469
5570using Flux
5671
# Benchmark the three `trainables*` prototypes on a small Flux model:
# forward collection (wrapped in `floss` so results are comparable) and the
# full Zygote gradient. `trainables2`/`trainables3` are defined elsewhere in
# this file (outside this hunk).
function perf()
    # Model mixes Dense layers, a BatchNorm (has non-trainable statistics),
    # and an anonymous activation, to exercise the structural walk.
    m = Chain(Dense(128 => 128, relu),
              Dense(128 => 128, relu),
              BatchNorm(128),
              x -> x^2,
              Dense(128 => 128, relu),
              Dense(128 => 128, relu))

    println("trainables1")
    @btime floss(trainables1($m))
    println("trainables2")
    @btime floss(trainables2($m))
    println("trainables3")
    @btime floss(trainables3($m))
    println()

    # trainables1 is differentiable only via its custom rrule above.
    println("gradient trainables1")
    @btime gradient(m -> floss(trainables1(m)), $m)
    println("gradient trainables2")
    @btime gradient(m -> floss(trainables2(m)), $m)
    println("gradient trainables3")
    @btime gradient(m -> floss(trainables3(m)), $m)

    return nothing
end
7997
Zygote.refresh()
perf()


# Correctness check: the custom-rrule path (trainables1) must agree with the
# independent implementation trainables2 on the same model.
m = Chain(Dense(128 => 128, relu),
          Dense(128 => 128, relu),
          BatchNorm(128),
          x -> x^2,
          Dense(128 => 128, relu),
          Dense(128 => 128, relu))

floss(trainables1(m))  # smoke-test the forward pass before differentiating
g1 = gradient(m -> floss(trainables1(m)), m)[1]
g2 = gradient(m -> floss(trainables2(m)), m)[1]
@test g1.layers[1].weight ≈ g2.layers[1].weight
# BUG FIX: the original asserted layers[1].weight twice (copy-paste
# duplicate); compare a second trainable layer instead.
@test g1.layers[2].weight ≈ g2.layers[2].weight
# BatchNorm's running statistics are non-trainable, so neither gradient
# should carry a tangent for μ.
@test g1.layers[3].μ === nothing
@test g2.layers[3].μ === nothing
0 commit comments