@@ -6,36 +6,41 @@ using Functors: fmap
 
 import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions
 
-export setup, @train_autodiff
+export setup, train!
 
 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
 
 """
     opt = setup(rule, model)
 
-This is a version of `Optimisers.setup`, and is the first step before using `train!`.
+This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
 It differs from `Optimisers.setup` in that it:
 * has one extra check for mutability
 * has methods which accept Flux's old optimisers, and convert them.
 
+# Example
 ```jldoctest
 julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);
 
-julia> opt = Flux.setup(Momentum(0.11), model)
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ())
+julia> opt = Flux.setup(Momentum(0.1), model)  # this encodes the optimiser and its state
+(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())
 
-julia> Flux.train!(model, opt) do m  # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4])
-         sum(m([0.2, -0.3]) .- [0.4]) * 100
+julia> x1, y1 = [0.2, -0.3], [0.4];  # use the same data for two steps:
+
+julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
+         sum(abs.(m(x) .- y)) * 100
        end
--40.1
+2-element Vector{Float32}:
+ 40.1
+ 38.7
 
 julia> model.bias  # was zero, mutated by Flux.train!
 1-element Vector{Float32}:
- -0.11
+ 10.190001
 
 julia> opt  # mutated by Flux.train!
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ())
+(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
 ```
 """
 function setup(rule::Optimisers.AbstractRule, model)
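The hunk above documents that `Flux.setup` wraps `Optimisers.setup` and also converts Flux's old-style rules. A minimal sketch of that usage, not from this commit, assuming a Flux 0.13-era environment with Optimisers.jl installed:

```julia
using Flux            # assumes Flux 0.13-era exports (Dense, Momentum, ...)
import Optimisers     # assumes Optimisers.jl is installed alongside Flux

model = Dense(2 => 1)

# Old-style Flux rule: setup converts it to an Optimisers.jl rule plus a state tree.
opt_from_old = Flux.setup(Flux.Momentum(0.1), model)

# Optimisers.jl rule passed directly: hits the method whose signature closes the hunk above.
opt_direct = Flux.setup(Optimisers.Momentum(0.1), model)
```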
@@ -51,18 +56,8 @@
     train!(loss, model, data, opt)
 
 Uses a `loss` function and training `data` to improve the `model`'s parameters
-according to a particular optimisation rule `opt`.
-
-!!! note
-    This method has significant changes from the one in Flux ≤ 0.13:
-    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
-      (This is to move away from Zygote's implicit parameter handling.)
-    * Instead of `loss` being a function which typically accepts two arguments
-      (the input `x` and expected output `y` from each element of `data`)
-      now it should typically accept three, the first of which is the `model` itself.
-    * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
-    * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
-    * Callback functions are not supported.
+according to a particular optimisation rule `opt`. Iterates through `data` once,
+evaluating `loss(model, d...)` for each `d` in `data`.
 
 For example, with these definitions...
 ```
@@ -72,15 +67,17 @@ loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
 
 opt = Flux.setup(Adam(), model)  # explicit setup of optimiser momenta
 ```
-...calling `train!(loss3, model, data, opt)` runs a loop much like this:
+...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
+using Zygote's "explicit" mode for the gradient:
 ```
 for d in data
-    ∂L∂m = Zygote.gradient(loss3, model, d...)[1]
-    Optimisers.update!(opt, model, ∂L∂m)
+    ∂L∂m = gradient(loss3, model, d...)[1]
+    update!(opt, model, ∂L∂m)  # method for "explicit" gradient
 end
 ```
 You can also write this loop yourself, if you need more flexibility.
-Besides the loop, `train!` will:
+For this reason `train!` is not highly extensible.
+It adds only a few features to the loop above:
 
 * Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
 
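To make the comparison in that hunk concrete, here is a runnable version of the hand-written loop the docstring describes; an editorial sketch, not part of the commit, assuming Flux re-exports `gradient` and that Optimisers.jl is available. The `loss3` definition and the synthetic `data` mirror the docstring's example.

```julia
using Flux                      # provides gradient (via Zygote), Dense, Adam
import Optimisers
using LinearAlgebra: norm

model = Dense(2 => 1)
loss3(m, x, y) = norm(m(x) .- y)                 # the model is the first argument
data = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:4]

opt = Flux.setup(Adam(), model)                  # explicit setup of optimiser momenta

for d in data
    ∂L∂m = gradient(loss3, model, d...)[1]       # gradient with respect to the model itself
    Optimisers.update!(opt, model, ∂L∂m)         # mutates both the state tree and the model
end
```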
@@ -91,20 +88,36 @@ Besides the loop, `train!` will:
 Note that the built-in loss functions accept 3 arguments, allowing for instance
 `train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
 
-Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
+!!! note
+    This method has significant changes from the one in Flux ≤ 0.13:
+    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
+      (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
+    * Instead of `loss` being a function which typically accepts two arguments
+      (the input `x` and expected output `y` from each element of `data`)
+      now it should typically accept three, the first of which is the `model` itself.
+    * `data` must iterate tuples, otherwise you get an error.
+      (Previously non-tuple types were not splatted into the loss.
+      Pass in `((d,) for d in data)` to simulate this.)
+    * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
+      such as `Adam()` without this step should give you a warning.
+    * Callback functions are not supported.
+      But any code can be included in the above `for` loop.
 """
-function train!(loss, model, data, opt)
+function train!(loss, model, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
   losses = Float32[]
   @withprogress for (i,d) in enumerate(data)
     d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
                             Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
-    l, (g, _...) = explicit_withgradient(loss, model, d...)
+    # l, (g, _...) = explicit_withgradient(loss, model, d...)  # BTW this un-thunks gradient w.r.t. data. Could avoid that
+    l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
     isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
     opt, model = Optimisers.update!(opt, model, g)
     push!(losses, l)
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
-  return losses  # Not entirely sure returning losses is a good idea
+  return losses  # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models à la Optimisers.jl
 end
 
 # This method lets you use Optimisers.Descent() without setup, when there is no state
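Finally, an end-to-end sketch of the new-style API this commit documents; editorial, not part of the commit, assuming Flux 0.13-era exports. It shows the tuple-iterating `data` requirement, a three-argument loss, and the vector of losses that `train!` now returns; the wrapped-data form from the note is shown as a comment with a hypothetical iterator name.

```julia
using Flux

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
opt = Flux.setup(Adam(0.01), model)          # Flux.setup, not Optimisers.setup

xs = [randn(Float32, 2) for _ in 1:8]
ys = [randn(Float32, 1) for _ in 1:8]
data = zip(xs, ys)                           # iterates tuples (x, y), as train! requires

loss(m, x, y) = Flux.Losses.mse(m(x), y)     # model comes first, matching loss(model, d...)

losses = Flux.train!(loss, model, data, opt) # one pass over data; returns the recorded losses

# If an iterator yields single objects rather than tuples, wrap it as the note above suggests
# (untupled_data is a hypothetical name for such an iterator):
# Flux.train!(loss, model, ((d,) for d in untupled_data), opt)
```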