@@ -43,16 +43,60 @@ function Base.show(io::IO, opt::FluxState)
   end
 end

+_DESCENT_EXAMPLE = """# Implicit-style example
+This usage matches Flux ≤ v0.13:
+```
+opt = Flux.Descent(0.3)
+
+ps = Flux.params(model)  # returns a Zygote.Params object
+
+gs = gradient(ps) do  # gradient takes a zero-argument anonymous function
+  loss3(model, x, y)  # ... which depends on the global model
+end  # ... and returns a Zygote.Grads object
+
+Flux.update!(opt, ps, gs)
+```
+New in Flux v0.14 is a method `train!(loss, ps, opt)` which performs one step,
+rather than iterating over `data`. It is equivalent to the `gradient` and `update!` calls above:
+```
+Flux.train!(ps, opt) do
+  loss3(model, x, y)
+end
+```
+
+# Explicit-style example
+
+This no longer uses `Flux.params`; instead it uses the model itself:
+```
+opt = Flux.Descent(0.3)  # the same FluxState object
+
+Flux.train!(model, opt) do m  # now explicitly depends on the model
+  loss3(m, x, y)
+end
+```
+"""
 for opt in [
   :Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
   :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :AdamW, :RAdam, :OAdam, :AdaBelief,
-  # :InvDecay, :ExpDecay, :WeightDecay, :stop, :skip, :Optimiser,
-  # :ClipValue, :ClipNorm,
-  # TODO check that parameters line up nicely old-vs-new, and include the remaining rules
+  # :InvDecay, :ExpDecay, :WeightDecay, :Optimiser,
+  :ClipGrad, :ClipNorm,
+  # TODO sort out the remaining rules
 ]
-  @eval $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+  @eval begin
+    $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+    str = string("""    Flux.$($opt)(args...)
+
+Returns a `FluxState` wrapper around the following rule definition from Optimisers.jl,
+allowing its use with `Flux.train!` (in the same manner as `Flux.AbstractOptimiser` objects in Flux ≤ v0.13).
+Accepts the same arguments, with the same defaults, as the underlying rule:
+
+""", @doc(Optimisers.$opt), $opt == Descent ? _DESCENT_EXAMPLE : "")
+    @doc str $opt
+  end
 end

+@deprecate ClipValue ClipGrad
+

 ### Two styles of gradient, and their `train!` functions

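For illustration, here is a minimal sketch of how the wrappers generated by the `@eval` loop above would be used with the explicit-style `train!` that the `_DESCENT_EXAMPLE` docstring describes. The model, data, and the definition of `loss3` are assumptions made for this example, not part of the commit:

```julia
using Flux

# Any small model and batch of data will do for the sketch.
model = Flux.Dense(2 => 1)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss3(m, x, y) = Flux.Losses.mse(m(x), y)   # a loss that takes the model explicitly

# On this branch, Flux.Adam is generated by the @eval loop above: it returns a
# FluxState wrapping Optimisers.Adam, with the same arguments and defaults.
opt = Flux.Adam(0.01)

# Explicit style, as in the docstring example: the do-block receives the model.
Flux.train!(model, opt) do m
    loss3(m, x, y)
end
```

The same `opt` could also be used in the implicit style (`Flux.params` plus `gradient`/`update!`), which `_DESCENT_EXAMPLE` above also documents.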