@@ -67,7 +67,7 @@
 ###
 
 """
-    Optimisers.setup(rule, model) -> tree
+    Optimisers.setup(rule, model) -> state_tree
 
 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
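As a side note on the renamed return value: a minimal sketch of what `setup` produces, assuming only the public `Optimisers.setup` and `Momentum` API (the model here is illustrative, not from the package):

```julia
using Optimisers

# Any nested structure of arrays can serve as a model; `setup` walks it
# and attaches one mutable `Leaf` of optimiser state per trainable array.
model = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0])

state_tree = Optimisers.setup(Momentum(0.01, 0.9), model)

# `state_tree` mirrors the model's structure: a NamedTuple with fields
# `weight` and `bias`, each a `Leaf` holding the rule and its momentum buffer.
```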
@@ -142,6 +142,7 @@ This is used in exactly the same manner as [`update`](@ref), but because it may
 arrays within the old model (and the old state), it will be faster for models of ordinary
 `Array`s or `CuArray`s. However, you should not rely on the old model being fully updated
 but rather use the returned model.
+(The original state tree is always mutated, as each `Leaf` is mutable.)
 
 # Example
 
@@ -150,9 +151,10 @@ julia> using StaticArrays, Zygote, Optimisers
 
 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
 
-julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
+julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
+(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
 
-julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
+julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
 
 julia> t2, m2 = Optimisers.update!(t, m, g);
@@ -166,7 +168,7 @@
 julia> m # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
 
-julia> t == t2 # original state is in fact guaranteed to be mutated
+julia> t == t2 # original state tree is guaranteed to be mutated
 true
 ```
 """