@@ -66,7 +66,7 @@
 # ##

 """
-    Optimisers.setup(rule, model) -> tree
+    Optimisers.setup(rule, model) -> state_tree

 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
@@ -141,6 +141,7 @@ This is used in exactly the same manner as [`update`](@ref), but because it may
 arrays within the old model (and the old state), it will be faster for models of ordinary
 `Array`s or `CuArray`s. However, you should not rely on the old model being fully updated
 but rather use the returned model.
+(The original state tree is always mutated, as each `Leaf` is mutable.)

 # Example
@@ -149,9 +150,10 @@ julia> using StaticArrays, Zygote, Optimisers

 julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model

-julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
+julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
+(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))

-julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
+julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])

 julia> t2, m2 = Optimisers.update!(t, m, g);
@@ -165,7 +167,7 @@
 julia> m # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])

-julia> t == t2 # original state is in fact guaranteed to be mutated
+julia> t == t2 # original state tree is guaranteed to be mutated
 true
 ```
 """