@@ -5,6 +5,8 @@ using Optimisers: @.., @lazy
55
66Random. seed! (1 )
77
8+ # Fake "models" for testing
9+
810struct Foo; x; y; end
911Functors. @functor Foo
1012Optimisers. trainable (x:: Foo ) = (x. y, x. x)
@@ -16,6 +18,8 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
1618mutable struct MutTwo; x; y; end
1719Functors. @functor MutTwo
1820
21+ # Simple rules for testing
22+
1923struct DummyHigherOrder <: AbstractRule end
2024Optimisers. init (:: DummyHigherOrder , x:: AbstractArray ) =
2125 (ones (eltype (x), size (x)), zero (x))
227231 @test_throws MethodError Optimisers. update (sm, m)
228232 end
229233
230- @testset " 2nd order gradient" begin
231- m = (α = ([1.0 ], sin), γ = Float32[4 ,3 ,2 ])
232-
233- # Special rule which requires this:
234- s = Optimisers. setup (BiRule (), m)
235- g = (α = ([0.1 ], ZeroTangent ()), γ = [1 ,10 ,100 ],)
236- s1, m1 = Optimisers. update (s, m, g, g)
237- @test m1. α[1 ] == [0.9 ]
238- @test_throws Exception Optimisers. update (s, m, g, map (x-> 2 .* x, g))
239-
240- # Ordinary rule which doesn't need it:
241- s2 = Optimisers. setup (Adam (), m)
242- s3, m3 = Optimisers. update (s2, m, g)
243- s4, m4 = Optimisers. update (s2, m, g, g)
244- @test m3. γ == m4. γ
245- end
246-
247234 @testset " broadcasting macros" begin
248235 x = [1.0 , 2.0 ]; y = [3 ,4 ]; z = [5 ,6 ]
249236 @test (@lazy x + y * z) isa Broadcast. Broadcasted
@@ -365,34 +352,53 @@ end
365352 @test model2. a === model2. b # tie of MutTwo structs is restored
366353 @test model2. a != = model2. c # but a new tie is not created
367354 end
368- end
355+ end # tied weights
356+
357+ @testset " 2nd-order interface" begin
358+ @testset " BiRule" begin
359+ m = (α = ([1.0 ], sin), γ = Float32[4 ,3 ,2 ])
360+
361+ # Special rule which requires this:
362+ s = Optimisers. setup (BiRule (), m)
363+ g = (α = ([0.1 ], ZeroTangent ()), γ = [1 ,10 ,100 ],)
364+ s1, m1 = Optimisers. update (s, m, g, g)
365+ @test m1. α[1 ] == [0.9 ]
366+ @test_throws Exception Optimisers. update (s, m, g, map (x-> 2 .* x, g))
367+
368+ # Ordinary rule which doesn't need it:
369+ s2 = Optimisers. setup (Adam (), m)
370+ s3, m3 = Optimisers. update (s2, m, g)
371+ s4, m4 = Optimisers. update (s2, m, g, g)
372+ @test m3. γ == m4. γ
373+ end
369374
370- @testset " higher order interface" begin
371- w, b = rand (3 , 4 ), rand (3 )
372-
373- o = DummyHigherOrder ()
374- psin = (w, b)
375- dxs = map (x -> rand (size (x)... ), psin)
376- dx2s = map (x -> rand (size (x)... ), psin)
377- stin = Optimisers. setup (o, psin)
378- stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
379-
380- # hardcoded rule behavior for dummy rule
381- @test psout[1 ] == dummy_update_rule (stin[1 ]. state, psin[1 ], dxs[1 ], dx2s[1 ])
382- @test psout[2 ] == dummy_update_rule (stin[2 ]. state, psin[2 ], dxs[2 ], dx2s[2 ])
383- @test stout[1 ]. state[1 ] == stin[1 ]. state[1 ] .+ 1
384- @test stout[2 ]. state[2 ] == stin[2 ]. state[2 ] .+ 1
385-
386- # error if only given one derivative
387- @test_throws MethodError Optimisers. update (stin, psin, dxs)
388-
389- # first-order rules compose with second-order
390- ochain = OptimiserChain (Descent (0.1 ), o)
391- stin = Optimisers. setup (ochain, psin)
392- stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
393- @test psout[1 ] == dummy_update_rule (stin[1 ]. state[2 ], psin[1 ], 0.1 * dxs[1 ], dx2s[1 ])
394- @test psout[2 ] == dummy_update_rule (stin[2 ]. state[2 ], psin[2 ], 0.1 * dxs[2 ], dx2s[2 ])
395- end
375+ @testset " DummyHigherOrder" begin
376+ w, b = rand (3 , 4 ), rand (3 )
377+
378+ o = DummyHigherOrder ()
379+ psin = (w, b)
380+ dxs = map (x -> rand (size (x)... ), psin)
381+ dx2s = map (x -> rand (size (x)... ), psin)
382+ stin = Optimisers. setup (o, psin)
383+ stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
384+
385+ # hardcoded rule behavior for dummy rule
386+ @test psout[1 ] == dummy_update_rule (stin[1 ]. state, psin[1 ], dxs[1 ], dx2s[1 ])
387+ @test psout[2 ] == dummy_update_rule (stin[2 ]. state, psin[2 ], dxs[2 ], dx2s[2 ])
388+ @test stout[1 ]. state[1 ] == stin[1 ]. state[1 ] .+ 1
389+ @test stout[2 ]. state[2 ] == stin[2 ]. state[2 ] .+ 1
390+
391+ # error if only given one derivative
392+ @test_throws MethodError Optimisers. update (stin, psin, dxs)
393+
394+ # first-order rules compose with second-order
395+ ochain = OptimiserChain (Descent (0.1 ), o)
396+ stin = Optimisers. setup (ochain, psin)
397+ stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
398+ @test psout[1 ] == dummy_update_rule (stin[1 ]. state[2 ], psin[1 ], 0.1 * dxs[1 ], dx2s[1 ])
399+ @test psout[2 ] == dummy_update_rule (stin[2 ]. state[2 ], psin[2 ], 0.1 * dxs[2 ], dx2s[2 ])
400+ end
401+ end # 2nd-order
396402
397403 end
398404 @testset verbose= true " Destructure" begin
0 commit comments