1-
1+ using BenchmarkTools
2+ using Optimisers
3+ using Functors
4+ using Zygote, Flux
25
36function trainables1 (x)
4- isnumeric (x) && return [x]
7+ Optimisers . isnumeric (x) && return [x]
58 arrays = AbstractArray[]
6- fmap (x; exclude = isnumeric, walk = _TrainableStructWalk ()) do y
9+ exclude (x) = Optimisers. isnumeric (x) && Functors. isleaf (x)
10+ fmap (x; exclude, walk = Optimisers. _TrainableStructWalk ()) do y
711 push! (arrays, y)
812 return y
913 end
@@ -17,19 +21,61 @@ using Functors: AbstractWalk, _map, _values, execute, ExcludeWalk
"""
    TrainableWalk2()

Walk that recurses into the `Optimisers.trainable` children of a node and
concatenates whatever the recursion returns for each child into one vector.
"""
struct TrainableWalk2 <: AbstractWalk end

function (walk::TrainableWalk2)(recurse, x, ys...)
    children = Optimisers.trainable(x)
    other_children = map(Optimisers.trainable, ys)
    collected = map(recurse, children, other_children...)
    # Seeding the reduction with an empty `Any[]` keeps it defined when a node
    # has no trainable children at all.
    return reduce(vcat, values(collected); init=Any[])
end
2629
"""
    trainables2(x)

Collect the numeric leaf arrays reachable from `x` through `TrainableWalk2`.

Every node that is both `Optimisers.isnumeric` and `Functors.isleaf` is wrapped
in a one-element vector; the walk concatenates those vectors on the way back up.
"""
function trainables2(x)
    is_numeric_leaf(node) = Optimisers.isnumeric(node) && Functors.isleaf(node)
    wrap(leaf) = [leaf]
    walk = ExcludeWalk(TrainableWalk2(), wrap, is_numeric_leaf)
    return execute(walk, x)
end
34+
35+
"""
    TrainableWalk3()

Variant of `TrainableWalk2` that joins the per-child results by splatting them
into a single `vcat` call instead of reducing from an empty seed.
"""
struct TrainableWalk3 <: AbstractWalk end

function (walk::TrainableWalk3)(recurse, x, ys...)
    children = Optimisers.trainable(x)
    other_children = map(Optimisers.trainable, ys)
    collected = map(recurse, children, other_children...)
    pieces = values(collected)
    return vcat(pieces...)
end
44+
"""
    trainables3(x)

Collect the numeric arrays reachable from `x` through `TrainableWalk3`.

Recursion stops at every node for which `Optimisers.isnumeric` is true; such a
node is wrapped in a one-element vector and the walk concatenates the results.
"""
function trainables3(x)
    is_numeric(node) = Optimisers.isnumeric(node)
    wrap(leaf) = [leaf]
    walk = ExcludeWalk(TrainableWalk3(), wrap, is_numeric)
    return execute(walk, x)
end
49+
50+
"""
    floss(ps)

Return the sum of all entries of every element of `ps`: each element is summed
on its own, then the per-element sums are added together.
"""
function floss(ps)
    per_array = map(sum, ps)
    return sum(per_array)
end
3154
3255using Flux
3356
34- m = Chain (Dense (2 => 3 , relu), BatchNorm (3 ), Dense (3 => 2 ))
35- trainables2 (m)
"""
    perf()

Benchmark the `trainables1`, `trainables2` and `trainables3` collectors on a
Flux `Chain`, then benchmark taking the gradient of `floss` through
`trainables2` and `trainables3`. Timings are printed by `@btime`, each preceded
by a label line.
"""
function perf()
    # Every layer is an argument of the single `Chain` call. Closing the call
    # early would leave the trailing `Dense` layers as a discarded tuple
    # expression, so they would never be part of the benchmarked model.
    m = Chain(
        Dense(128 => 128, relu),
        Dense(128 => 128, relu),
        BatchNorm(128),
        # NOTE(review): 3 => 2 does not match the 128-wide neighbours. The model
        # is only traversed here, never applied to data - confirm it is intended.
        Dense(3 => 2),
        x -> x^2,
        Dense(128 => 128, relu),
        Dense(128 => 128, relu),
    )

    println(" trainables1")
    @btime trainables1($m)
    println(" trainables2")
    @btime trainables2($m)
    println(" trainables3")
    @btime trainables3($m)
    println()

    # No gradient benchmark for `trainables1`: it is not differentiable because
    # it collects its result by mutation.
    println(" gradient trainables2")
    @btime gradient(m -> floss(trainables2(m)), $m)
    println(" gradient trainables3")
    @btime gradient(m -> floss(trainables3(m)), $m)
end
79+
# Script entry point: refresh Zygote first (presumably so it picks up the
# definitions above before differentiating - confirm), then run the benchmarks.
Zygote.refresh()
perf()
0 commit comments