|
| 1 | +### |
| 2 | +### freezing |
| 3 | +### |
| 4 | + |
| 5 | +""" |
| 6 | + Optimisers.freeze!(tree) |
| 7 | +
|
| 8 | +Temporarily alters the state `tree = setup(rule, model)` so that parameters |
| 9 | +will not be updated. Un-done by [`thaw!`](@ref Optimisers.thaw!). |
| 10 | +
|
| 11 | +Can be applied to the state corresponding to only part of a model, |
| 12 | +for instance with `model::Chain`, to freeze `model.layers[1]` you |
| 13 | +should call `freeze!(tree.layers[1])`. |
| 14 | +
|
| 15 | +# Example |
| 16 | +```jldoctest |
| 17 | +julia> m = (x = ([1.0], 2.0), y = [3.0]); |
| 18 | +
|
| 19 | +julia> s = Optimisers.setup(Momentum(), m); |
| 20 | +
|
| 21 | +julia> Optimisers.freeze!(s.x) |
| 22 | +
|
| 23 | +julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient |
| 24 | +
|
| 25 | +julia> m |
| 26 | +(x = ([1.0], 2.0), y = [-0.14159258336972558]) |
| 27 | +
|
| 28 | +julia> s |
| 29 | +(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159])) |
| 30 | +
|
| 31 | +julia> Optimisers.thaw!(s) |
| 32 | +
|
| 33 | +julia> s.x |
| 34 | +(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ()) |
| 35 | +``` |
| 36 | +""" |
| 37 | +freeze!(tree) = foreach(freeze!, tree) |
| 38 | +freeze!(ℓ::Leaf) = (ℓ.frozen = true; nothing) |
| 39 | + |
| 40 | +""" |
| 41 | + Optimisers.thaw!(tree) |
| 42 | +
|
| 43 | +The reverse of [`freeze!`](@ref Optimisers.freeze!). Applies to all parameters, |
| 44 | +mutating every `Leaf(rule, state, frozen = true)` to `Leaf(rule, state, frozen = false)`. |
| 45 | +""" |
| 46 | +thaw!(tree) = foreach(thaw!, tree) |
| 47 | +thaw!(ℓ::Leaf) = (ℓ.frozen = false; nothing) |
| 48 | + |
| 49 | +freeze!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError( |
| 50 | + "`freeze!` must not be applied to a model, only to the state tree from `setup`")) |
| 51 | +thaw!(::Union{Number, AbstractArray{<:Number}}) = throw(ArgumentError( |
| 52 | + "`thaw!` must not be applied to a model, only to the state tree from `setup`")) |
| 53 | + |
| 54 | +### |
| 55 | +### adjust |
| 56 | +### |
1 | 57 |
|
2 | 58 | """ |
3 | 59 | Optimisers.adjust(tree, η) -> tree |
@@ -47,8 +103,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree) |
47 | 103 | adjust(::Nothing, ::Real) = nothing |
48 | 104 | adjust(::Nothing; kw...) = nothing |
49 | 105 |
|
50 | | -adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state) |
51 | | -adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state) |
| 106 | +adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen) |
| 107 | +adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen) |
52 | 108 |
|
53 | 109 |
|
54 | 110 | """ |
|
0 commit comments