1 | | -using BenchmarkTools |
2 | | -using Optimisers |
3 | | -using Functors |
4 | | -using Zygote, Flux |
5 | | -using ChainRulesCore |
6 | 1 |
7 | | -function trainables1(x) |
| 2 | +""" |
| 3 | + trainables(x) |
| 4 | +
| 5 | +Return an iterable over all the trainable parameters in `x`, that is, all the numerical
| 6 | +arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable). |
| 7 | +
| 8 | +Parameters appearing multiple times in the model will be present only once in the output. |
| 9 | +
| 10 | +See also [`destructure`](@ref). |
| 11 | +
| 12 | +# Examples |
| 13 | +
| 14 | +```jldoctest |
| 15 | +julia> struct MyLayer |
| 16 | + w |
| 17 | + b |
| 18 | + end |
| 19 | +
| 20 | +julia> Functors.@functor MyLayer |
| 21 | +
| 22 | +julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example |
| 23 | +
| 24 | +julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]); |
| 25 | +
| 26 | +julia> trainables(x) |
| 27 | +1-element Vector{AbstractArray}: |
| 28 | + [1.0, 2.0, 3.0] |
| 29 | +""" |
| 30 | +function trainables(x) |
8 | 31 | arrays = AbstractArray[] |
9 | 32 | exclude(x) = Optimisers.isnumeric(x) |
10 | | - fmap(x; exclude, walk = Optimisers._TrainableStructWalk()) do y |
| 33 | + fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y |
11 | 34 | push!(arrays, y) |
12 | 35 | return y |
13 | 36 | end |
14 | 37 | return arrays |
15 | 38 | end |
16 | 39 |
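Since `fmap` caches visited objects in an `IdDict`, an array reached through more than one field is collected only once, which is what the docstring's claim about repeated parameters relies on. A minimal sketch of that behaviour (the `Shared` struct is hypothetical, and assumes the `trainables` defined above is in scope):

```julia
using Functors, Optimisers

struct Shared
    a
    b
end
Functors.@functor Shared

w = [1.0, 2.0, 3.0]
m = Shared(w, w)    # the same array reachable through two fields

ps = trainables(m)
length(ps)          # 1, not 2: fmap's IdDict cache skips the repeat
```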
17 | | -function ∇trainables1(x, Δ) |
| 40 | +function ∇trainables(x, Δ) |
18 | 41 | exclude(x) = Optimisers.isnumeric(x) |
19 | 42 | i = 0 |
20 | | - return fmapstructure(x; exclude, walk = Optimisers._TrainableStructWalk()) do _ |
| 43 | + return fmapstructure(x; exclude, walk = Optimisers.TrainableStructWalk()) do _ |
21 | 44 | return Δ[i+=1] |
22 | 45 | end |
23 | 46 | end |
24 | 47 |
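`∇trainables` replays the same walk with `fmapstructure`, consuming the flat cotangent vector `Δ` in exactly the order in which `trainables` pushed the parameters, and returns a nested tangent mirroring the trainable structure of `x`. A sketch with the `MyLayer` type from the docstring (hypothetical values):

```julia
x = MyLayer([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
Δ = [[2.0, 4.0, 6.0]]    # one cotangent per collected parameter

∇trainables(x, Δ)        # (w = [2.0, 4.0, 6.0],), mirroring `trainable(x)`
```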
25 | | - |
26 | | -function ChainRulesCore.rrule(::typeof(trainables1), x) |
27 | | - y = trainables1(x) |
28 | | - trainables_back(Δ) = (NoTangent(), ∇trainables1(x, unthunk(Δ))) |
| 48 | +function ChainRulesCore.rrule(::typeof(trainables), x) |
| 49 | + y = trainables(x) |
| 50 | + trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ))) |
29 | 51 | return y, trainables_back |
30 | 52 | end |
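With this `rrule` in place, reverse-mode AD differentiates straight through the flat parameter view. A quick sanity check (a sketch, assuming Zygote is loaded and reusing the `MyLayer` example from the docstring):

```julia
using Zygote

x = MyLayer([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
loss(m) = sum(abs2, first(trainables(m)))   # only `w` is collected; `b` is not trainable
g = Zygote.gradient(loss, x)[1]
g.w                                         # [2.0, 4.0, 6.0], i.e. 2 .* x.w
```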
31 | | - |
32 | | -############ |
33 | | - |
34 | | -using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk |
35 | | - |
36 | | -struct TrainableWalk2 <: AbstractWalk end |
37 | | - |
38 | | -function (walk::TrainableWalk2)(recurse, x, ys...) |
39 | | - x_children = Optimisers.trainable(x) |
40 | | - ys_children = map(Optimisers.trainable, ys) |
41 | | - res = map(recurse, x_children, ys_children...) |
42 | | - return reduce(vcat, values(res),init=[]) |
43 | | -end |
44 | | - |
45 | | -function trainables2(x) |
46 | | - exclude(x) = Optimisers.isnumeric(x) && Functors.isleaf(x) |
47 | | - return execute(ExcludeWalk(TrainableWalk2(), x ->[x], exclude), x) |
48 | | -end |
49 | | - |
50 | | - |
51 | | -struct TrainableWalk3 <: AbstractWalk end |
52 | | - |
53 | | -function (walk::TrainableWalk3)(recurse, x, ys...) |
54 | | - x_children = Optimisers.trainable(x) |
55 | | - ys_children = map(Optimisers.trainable, ys) |
56 | | - res = map(recurse, x_children, ys_children...) |
57 | | - return vcat(values(res)...) |
58 | | -end |
59 | | - |
60 | | -function trainables3(x) |
61 | | - exclude(x) = Optimisers.isnumeric(x) |
62 | | - return execute(ExcludeWalk(TrainableWalk3(), x ->[x], exclude), x) |
63 | | -end |
64 | | - |
65 | | - |
66 | | -function floss(ps) |
67 | | - sum([sum(abs2, p) for p in ps]) |
68 | | -end |
69 | | - |
70 | | -using Flux |
71 | | - |
72 | | -function perf() |
73 | | - m = Chain(Dense(128 => 128, relu), |
74 | | - Dense(128 => 128, relu), |
75 | | - BatchNorm(128), |
76 | | - x -> x^2, |
77 | | - Dense(128 => 128, relu), |
78 | | - Dense(128 => 128, relu)) |
79 | | - |
80 | | - println("trainables1") |
81 | | - @btime floss(trainables1($m)) |
82 | | - println("trainables2") |
83 | | - @btime floss(trainables2($m)) |
84 | | - println("trainables3") |
85 | | - @btime floss(trainables3($m)) |
86 | | - println() |
87 | | - |
88 | | - println("gradient trainables1") |
89 | | - @btime gradient(m -> floss(trainables1(m)), $m) |
90 | | - println("gradient trainables2") |
91 | | - @btime gradient(m -> floss(trainables2(m)), $m) |
92 | | - println("gradient trainables3") |
93 | | - @btime gradient(m -> floss(trainables3(m)), $m) |
94 | | - |
95 | | - nothing |
96 | | -end |
97 | | - |
98 | | -Zygote.refresh() |
99 | | -perf() |
100 | | - |
101 | | - |
102 | | -m = Chain(Dense(128 => 128, relu), |
103 | | - Dense(128 => 128, relu), |
104 | | - BatchNorm(128), |
105 | | - x -> x^2, |
106 | | - Dense(128 => 128, relu), |
107 | | - Dense(128 => 128, relu)) |
108 | | - |
109 | | -floss(trainables1(m)) |
110 | | -g1 = gradient(m -> floss(trainables1(m)), m)[1] |
111 | | -g2 = gradient(m -> floss(trainables2(m)), m)[1] |
112 | | -@test g1.layers[1].weight ≈ g2.layers[1].weight |
113 | | -@test g1.layers[1].bias ≈ g2.layers[1].bias
114 | | -@test g1.layers[3].μ === nothing |
115 | | -@test g2.layers[3].μ === nothing |