@@ -16,10 +16,19 @@ RULES = [
   OptimiserChain(WeightDecay(), OADAM(), ClipGrad(1)),
 ]
 
-name(o) = typeof(o).name.name
+name(o) = typeof(o).name.name  # just for printing testset headings
 name(o::OptimiserChain) = join(name.(o.opts), " → ")
 
+LOG = Dict()  # for debugging these testsets, this makes it easy to plot each optimiser's loss
+
+loggradient(o) = (f, xs...) -> begin
+  y, dxs = Zygote.withgradient(f, xs...)
+  push!(get!(() -> Float32[], LOG, name(o)), y)
+  dxs  # save the loss, return the gradient
+end
+
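+# Note: after a run, LOG[name(o)] holds one loss value per update step, so each rule's
+# trajectory can be inspected with e.g. lineplot(LOG["ADAM"]), assuming a plotting
+# package such as UnicodePlots is loaded (nothing below depends on this).
+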
 @testset "independence" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w = randn(10, 10)
     w′ = randn(10, 10)
@@ -28,22 +37,23 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
     st = Optimisers.setup(o, w)
     for t = 1:10^5
       x = rand(10)
-      gs = gradient(w -> iloss(x, w, w′), w)
+      gs = loggradient(o)(w -> iloss(x, w, w′), w)
       st, w = Optimisers.update!(st, w, gs...)
     end
     @test iloss(rand(10, 10), w, w′) < 0.01
   end
 end
 
 @testset verbose=true "simple sum" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     m = shuffle!(reshape(1:64, 8, 8) .+ 0.0)
     s = Optimisers.setup(o, m)
     for _ in 1:10^5
-      g = gradient(x -> sum(abs2, x + x'), m)[1]
+      g = loggradient(o)(x -> sum(abs2, x + x'), m)[1]
       s, m = Optimisers.update!(s, m, g)
     end
-    # @test sum(m) < sum(1:64)
+    @test sum(m) < sum(1:64)
     if sum(m) < 1
       @test sum(m) < 1
     else
⋮
 end
 
 @testset "original" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w′ = (α = rand(3, 3), β = rand(3, 3))
     w = (α = 5rand(3, 3), β = rand(3, 3))
     st = Optimisers.setup(o, w)
     loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
     @test loss(w, w′) > 1
     for i = 1:10^4
-      gs = gradient(x -> loss(x, w′), w)
+      gs = loggradient(o)(x -> loss(x, w′), w)
       st, w = Optimisers.update(st, w, gs...)
     end
     @test loss(w, w′) < 0.001
   end
 end
 
 @testset verbose=true "StaticArrays" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     W1 = @SMatrix randn(10, 10)
     b1 = @SVector randn(10)
⋮
     @test s_loss(model, x, y) > 10
     state = Optimisers.setup(o, model)
     for t = 1:10^3
-      g = gradient(m -> s_loss(m, x, y), model)[1]
+      g = loggradient(o)(m -> s_loss(m, x, y), model)[1]
       state, model = Optimisers.update!(state, model, g)
     end
     if o isa Descent
⋮
   end
 end
 
-@testset verbose=true "element types" begin
+@testset "element types" begin
   @testset "$(name(o))" for o in RULES
     marray = (Float32[1,2], Float64[3,4], Float16[5,6])
     types = map(eltype, marray)
⋮
   end
 end
 
+@testset "with complex numbers: Flux#1776" begin
+  empty!(LOG)
+  @testset "$(name(opt))" for opt in [
+    # The Flux PR used 1e-2 for all of these. But ADADelta's ρ needs to be ≈0.9, not small,
+    # and it helps to make its ε not too small either:
+    ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
+    # These weren't in the Flux PR:
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
+  ]
+    # Our "model" is just a complex number
+    model = (w = zeros(ComplexF64, 1),)
+
+    # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
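+    # (for the single data point x = 1+im used below, the exact solution is w = conj(x)/x = -im)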
+    function loss(m)
+      # Deterministic training data is the best training data
+      x = ones(1, 1) + 1im * ones(1, 1)
+      # Manually implement `mse()` to allow demonstration of brokenness
+      # on older Flux builds that don't have a fixed `mse()`
+      return sum(abs2.(m.w * x .- conj(x)))
+    end
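+    # With w = 0 the prediction is 0, so the starting loss is abs2(0 - conj(1+im)) = abs2(-1+im) = 2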
+    @test loss(model) ≈ 2.0
+
+    state = Optimisers.setup(opt, model)
+
+    # Train for 10 iterations, enforcing that loss is monotonically decreasing
+    last_loss = Inf
+    for idx in 1:10
+      grads = loggradient(opt)(loss, model)
+      state, model = Optimisers.update!(state, model, grads...)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue  # these are very flat at the end
+      @test loss(model) < last_loss
+      last_loss = loss(model)
+    end
+    @test loss(model) < 1.9
+
+    # Repeat with StaticArrays
+    static_model = (w = SA[0.0 + 0im],)
+    static_state = Optimisers.setup(opt, static_model)
+    function static_loss(m)
+      x = hcat(SA[1.0 + im])
+      sum(abs2.(m.w * x .- conj(x)))
+    end
+    last_loss = Inf
+    for idx in 1:10
+      grads = gradient(static_loss, static_model)
+      static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue
+      @test static_loss(static_model) < last_loss
+      last_loss = static_loss(static_model)
+    end
+    @test static_loss(static_model) < 1.9
+  end
+end