@@ -47,9 +47,6 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
-# opt = Flux.setup(Adam(), model); train!(model, opt) do m ...
-setup(model, rule::Optimisers.AbstractRule) = setup(rule, model)
-
 """
     train!(loss, model, data, opt)
 
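With the model-first `setup(model, rule)` convenience method gone, optimiser state is built rule-first only. A minimal usage sketch of the surviving call order (the `Dense` layer is just an illustrative stand-in for a real model):

```julia
using Flux

model = Dense(2 => 1)             # placeholder model
opt = Flux.setup(Adam(), model)   # rule first, then model — the order that remains
# Flux.setup(model, Adam())       # the flipped-argument form removed by this hunk
```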
@@ -112,56 +109,10 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end
 
-"""
-    train!(loss, model, opt)
-
-Uses a `loss` function to improve the `model`'s parameters.
-
-While the 4-argument method of `train!` iterates over a dataset,
-this 3-argument method is for a single datapoint, and calls `gradient` just once.
-It expects a function `loss` which takes just one argument, the model.
-For example:
-```
-opt = Flux.setup(Adam(), model)  # explicit setup
-train!(model, opt) do m          # the model is passed to the function as `m`
-  Flux.crossentropy(m(x1), y1)   # but the data point `(x1, y1)` is closed over.
-end
-```
-This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
-(The `do` block is another syntax for this anonymous function.)
-Then it updates the parameters contained within `model` according to `opt`.
-Finally it returns the value of the loss function.
-
-To iterate over a dataset, writing a loop allows more control than
-calling 4-argument `train!`. For example, this adds printing and an early stop:
-```
-data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
-opt = Flux.setup(Adam(), model)
-for (i, d) in enumerate(data)
-  x, y = d
-  ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt)
-  i%10==0 && println("on step \$i, the loss was \$ell")  # prints every 10th step
-  ell<0.1 && break                                       # stops training
-end
-```
-
-!!! note
-    This method has no implicit `Params` analog in Flux ≤ 0.13.
-"""
-function train!(loss, model, opt)
-  l, (g, _...) = explicit_withgradient(loss, model)
-  isfinite(l) || return l
-  _, model = Optimisers.update!(opt, model, g)
-  return l
-end
-
-# These methods let you use Optimisers.Descent() without setup, when there is no state
+# This method lets you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule)
   train!(loss, model, data, _rule_to_state(model, rule))
 end
-function train!(loss, model, rule::Optimisers.AbstractRule)
-  train!(loss, model, _rule_to_state(model, rule))
-end
 
 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
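The kept fallback above means a stateless rule can still be passed straight to the 4-argument `train!`, which builds the optimiser state internally via `_rule_to_state`. A usage sketch, assuming the explicit `train!` calls `loss(model, x, y)` for each `(x, y)` in `data` (the model, data, and loss here are placeholders):

```julia
using Flux
import Optimisers

model = Dense(2 => 1)
data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:8]   # placeholder dataset
loss(m, x, y) = Flux.mse(m(x), y)

Flux.train!(loss, model, data, Optimisers.Descent(0.1))  # no setup call: Descent carries no state
```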
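The removed 3-argument `train!` amounted to one `withgradient` call, a finiteness check, and an `Optimisers.update!`. The same single step can still be written out by hand; a sketch using `Zygote.withgradient` (the docstring's stated equivalent of the internal `explicit_withgradient` helper), with a placeholder model and datapoint:

```julia
using Flux
import Optimisers, Zygote

model = Dense(2 => 1)
x1, y1 = rand(Float32, 2), rand(Float32, 1)   # placeholder datapoint
opt = Flux.setup(Adam(), model)

# one explicit step, mirroring the deleted `train!(loss, model, opt)` body
l, (g, _...) = Zygote.withgradient(m -> Flux.mse(m(x1), y1), model)
if isfinite(l)                  # the removed method returned early on a non-finite loss
    _, model = Optimisers.update!(opt, model, g)
end
```

For a stateful rule like `Adam`, a real loop would keep the state returned by `Optimisers.update!` rather than discarding it as the deleted method did.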