|
216 | 216 | r = Any[nothing,m] |
217 | 217 | r[1] = r |
218 | 218 | @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] |
| 219 | + |
# Exercises `Flux.params` when it is called or referenced inside the function
# being differentiated, using both implicit-parameter calls
# (`gradient(f, ps)`) and an explicit-argument call (`gradient(f, m)`).
@testset "use params in gradient context" begin
    # Each parameter is summed exactly once inside the closure, so the
    # gradient with respect to every parameter should be all ones.
    m = Chain(Dense(3,2), Dense(2,2))
    ps = Flux.params(m)
    gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps)
    for p in ps
        @test gs[p] ≈ ones(size(p))
    end

    # Same check with plain arrays collected by `Flux.params`.
    w1, w2 = rand(2), rand(2)
    ps = Flux.params(w1, w2)
    gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps)
    for p in ps
        @test gs[p] ≈ ones(size(p))
    end

    # BROKEN TESTS
    # Indexing into `params(m)` inside the differentiated function: these are
    # expected to fail for now and are recorded as Broken instead of failing.
    m = Chain(Dense(3,2), Dense(2,2))
    # This assertion previously appeared twice verbatim; one copy is enough.
    @test_broken gradient(m -> sum(params(m)[1]), m) != (nothing, )

    gs = gradient(() -> sum(params(m)[1]), params(m))
    @test_broken gs[params(m)[1]] !== nothing

    # Tests from https://github.com/FluxML/Flux.jl/pull/1614
    m = Dense(3, 2)
    ps = Flux.params(m)
    data = rand(Float32, 3, 5)
    loss(m, x) = sum(m(x).^2)

    # g1 is the reference gradient: loss only, no `params` call in the closure.
    g1 = gradient(Flux.params(m)) do
        loss(m, data)
    end
    g2 = gradient(Flux.params(m)) do
        # NOTE(review): `ps` already exists in the enclosing scope, so this
        # assignment rebinds that variable rather than creating a new local.
        ps = Flux.params(m) # just creating params without using them
        loss(m, data)
    end
    # g3 builds `ps` inside the closure and also sums it into the objective.
    g3 = gradient(Flux.params(m)) do
        ps = Flux.params(m)
        loss(m, data) + sum(sum(p) for p in ps)
    end
    # g4 uses the `ps` captured from the enclosing scope.
    g4 = gradient(Flux.params(m)) do
        loss(m, data) + sum(sum(p) for p in ps)
    end
    # g5 / g6 index into the params collection instead of iterating over it.
    g5 = gradient(Flux.params(m)) do
        sum(Flux.params(m)[1]) + sum(Flux.params(m)[2])
    end
    g6 = gradient(Flux.params(m)) do
        sum(ps[1]) + sum(ps[2])
    end

    # Merely constructing params must not change the gradient.
    @test g2[m.weight] == g1[m.weight]
    # Adding sum(p) for every parameter adds one to each gradient entry.
    @test g3[m.weight] == g1[m.weight] .+ 1
    @test g4[m.weight] == g1[m.weight] .+ 1
    # `.==` compares element-wise; reduce with `all` so `@test_broken` sees a
    # single Bool and can flag an unexpected pass once the behaviour is fixed.
    @test_broken all(g5[m.weight] .== 1) # TODO regression with respect to master
    @test_broken all(g6[m.weight] .== 1) # Not a regression, broken on master
end
219 | 275 | end |
220 | 276 |
|
221 | 277 | @testset "Basic Stacking" begin |
|
0 commit comments