@@ -38,7 +38,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 end
 
 @testset "gradient clipping" begin
-  m = (α = ([0], sin), γ = rand(3))
+  m = (α = ([0.0], sin), γ = rand(3))
   s1 = Optimisers.setup(ClipGrad(13), m)
   _, m1 = Optimisers.update(s1, m, (α = nothing, γ = [1,10,100],))
   @test m.γ .- m1.γ ≈ [1, 10, 13]
@@ -58,7 +58,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 end
 
 @testset "OptimiserChain" begin
-  x = [1,10,100]; dx = [1,2,3];
+  x = [1, 10, 100.0]; dx = [1, 2, 3.0];
   @test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-3]
   @test Optimisers.update(Optimisers.setup(ClipGrad(2), x), x, dx)[2] ≈ [1-1, 10-2, 100-2]
 
@@ -81,7 +81,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 
 @testset "trainable subset" begin
   # Foo has an old-style tuple trainable, both elements
-  mf = Foo([1, 2], (a = sin, b = [3, 4], c = 5))
+  mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5))
   sf = Optimisers.setup(Descent(0.1), mf)
   gf = (x = nothing, y = (a = nothing, b = [1,1], c = 1))
   _, mf2 = Optimisers.update(sf, mf, gf)
@@ -116,6 +116,20 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   @test Optimisers.update!(s, m, g...)[2] isa Foo
 end
 
+@testset "eltype preservation" begin
+  m = (Float16[1,2], Float32[3,4])
+  s1 = Optimisers.setup(Descent(0.1), m)
+  s2, m2 = Optimisers.update(s1, m, m)
+  @test eltype(m2[1]) == Float16  # because update copies & calls update!
+  @test eltype(m2[2]) == Float32
+
+  staticm = (SA{Float16}[1,2], SA{Float32}[3,4])
+  s3 = Optimisers.setup(Descent(0.1), staticm)
+  s4, m4 = Optimisers.update(s3, staticm, staticm)
+  @test eltype(m4[1]) == Float16  # because of explicit broadcast in subtract!
+  @test eltype(m4[2]) == Float32
+end
+
 @testset "forgotten gradient" begin
   x = [1.0, 2.0]
   sx = Optimisers.setup(Descent(), x)
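
For context, a minimal sketch of the behaviour the new "eltype preservation" testset checks, assuming only Optimisers.jl's public API (Descent, setup, update) plus StaticArrays' SA constructor; the variable names below are illustrative and not taken from the commit:

  using Optimisers, StaticArrays

  m = (Float16[1, 2], SA{Float32}[3, 4])
  st = Optimisers.setup(Descent(0.1), m)  # the rule's learning rate is a Float64
  st, m2 = Optimisers.update(st, m, m)    # non-mutating: update copies, then calls update!
  eltype(m2[1])  # Float16, kept despite the Float64 learning rate
  eltype(m2[2])  # Float32, since the explicit broadcast in subtract! preserves the static array's eltype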