@@ -63,17 +63,17 @@ according to a particular optimisation rule `opt`.
 * Instead of `loss` being a function which typically accepts two arguments
   (the input `x` and expected output `y` from each element of `data`)
   now it should typically accept three, the first of which is the `model` itself.
-* `data` should iterate tuples or NamedTuples
-* `opt` should be the result of [`Flux.setup`](@ref).
+* `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
+* `opt` should be the result of [`Flux.setup`](@ref); it will warn you if not.
 * Callback functions are not supported.

 For example, with these definitions...
 ```
-data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple (or NamedTuple)
+data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple

-loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument
+loss3(m, x, y) = norm(m(x) .- y)   # the model is the first argument

-opt = Flux.setup(Adam(), model)  # explicit setup of optimiser momenta
+opt = Flux.setup(Adam(), model)    # explicit setup of optimiser momenta
 ```
 ...calling `train!(loss3, model, data, opt)` runs a loop much like this:
 ```
@@ -82,19 +82,28 @@ for d in data
     Optimisers.update!(opt, model, ∂L∂m)
 end
 ```
-Stops with a `DomainError` if the loss is infinite or `NaN` at any point.
+You can also write this loop yourself if you need more flexibility.
+Besides the loop, `train!` will:

-Returns a vector containing the value of the loss function at each datapoint.
+* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.

-The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`.
+* Return a vector containing the value of the loss function at each datapoint.

-Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
-easy way to construct more complicated training loops.
+* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
+
+Note that the built-in loss functions accept 3 arguments, allowing for instance
+`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
+
+Note that callback functions are not supported. Instead, arbitrary code can be inserted into the loop.
 """
 function train!(loss, model, data, opt)
+    Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function.
+        It must not close over the model, like loss(x,y) = mse(model(x), y).""")
     losses = Float32[]
     @withprogress for (i,d) in enumerate(data)
-        l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
+        d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
+            Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
+        l, (g, _...) = explicit_withgradient(loss, model, d...)
         isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
         opt, model = Optimisers.update!(opt, model, g)
         push!(losses, l)
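
The docstring above says you can also write this loop yourself. Below is a minimal hand-rolled sketch using only public API (`Flux.setup`, `Flux.withgradient`, `Optimisers.update!`) in place of the internal `explicit_withgradient` helper; the wrapper name `my_train!`, the toy model, and the random data are invented for illustration, and `withgradient` is assumed to be available (it is re-exported from Zygote in recent Flux versions).

```
using Flux, Optimisers

function my_train!(model, data, opt)      # hypothetical helper, not part of Flux
    for (i, (x, y)) in enumerate(data)
        # loss and gradient with respect to the model (the model is the first argument)
        l, grads = Flux.withgradient(m -> Flux.mse(m(x), y), model)
        isfinite(l) || error("loss was $l at step $i")
        # Optimisers.update! returns the new optimiser state and the updated model
        opt, model = Optimisers.update!(opt, model, grads[1])
        @info "step $i" loss=l            # arbitrary code here, instead of a callback
    end
    return model, opt
end

model = Dense(2 => 1)                                                 # toy model
data  = [(randn(Float32, 2, 8), randn(Float32, 1, 8)) for _ in 1:3]   # tuples of (x, y)
opt   = Flux.setup(Adam(), model)                                     # explicit optimiser state
model, opt = my_train!(model, data, opt)
```

Note that the anonymous loss here closes over `x` and `y`, which is fine for `withgradient`; a loss passed to `train!` itself must not capture variables (such as the model), or the `Base.issingletontype` guard above will reject it.
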
@@ -103,12 +112,6 @@ function train!(loss, model, data, opt)
     return losses  # Not entirely sure returning losses is a good idea
 end

-data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
-    To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
-data_splat(x::Tuple) = x
-data_splat(x::NamedTuple) = x
-data_splat(x::AbstractArray{<:Number}) = (x,)
-
 """
     train!(loss, model, opt)

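Putting the pieces together, an end-to-end call of the new 4-argument method might look like the sketch below. It leans on two behaviours documented in this diff, that the built-in losses accept `(model, x, y)` and that `train!` returns the per-datapoint losses; the model and data are invented for illustration.

```
using Flux, Optimisers

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))                        # toy model
data  = [(randn(Float32, 4, 16), randn(Float32, 1, 16)) for _ in 1:10]   # iterator of tuples
opt   = Flux.setup(Adam(), model)                                        # explicit optimiser state

# A built-in loss already takes the model as its first argument, so no loss3 is needed:
losses = train!(Flux.Losses.mse, model, data, opt)

# The returned vector holds the loss at each datapoint:
println("loss went from ", first(losses), " to ", last(losses))
```

If the elements of `data` were bare arrays rather than tuples, the `d isa Tuple` check above would raise an error; wrapping the iterator as `((d,) for d in data)` is the workaround it suggests.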