@@ -4,7 +4,7 @@ using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap
 
-export train!, update!, adjust!, FluxState, @epochs,
+export train!, update!, adjust!, FluxState,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
   # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -15,7 +15,7 @@ export train!, update!, adjust!, FluxState, @epochs,
 
 """
     FluxState(rule, state=missing)
-
+
 This is an interface between the all-mutable world Flux.jl likes,
 and the could-be-immutable world that Optimisers.jl inhabits.
 
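
The docstring above describes a mutable wrapper around an immutable Optimisers.jl rule. A minimal sketch of that idea, assuming only the `FluxState(rule, state=missing)` constructor shown above; the name `SketchState` and its field layout are illustrative stand-ins, not Flux's actual definition:

```julia
using Optimisers

# Illustrative stand-in for the idea behind FluxState: a mutable wrapper
# around an immutable Optimisers.jl rule, whose state tree starts out absent.
mutable struct SketchState{T<:Optimisers.AbstractRule}
    rule::T      # e.g. Optimisers.Adam()
    state::Any   # `missing` until initialised against a model
end
SketchState(rule) = SketchState(rule, missing)

# The state tree would be created on first use, e.g. via Optimisers.setup:
opt = SketchState(Optimisers.Adam())
# opt.state = Optimisers.setup(opt.rule, model)   # `model` not defined here
```
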
@@ -56,34 +56,14 @@
 
 # ## Two styles of gradient, and their `train!` functions
 
-using ProgressLogging: @progress, @withprogress, @logprogress
+using ProgressLogging: @progress, @withprogress, @logprogress  # TODO add progress logging again
 using Zygote: Zygote, Params
 
-include("explicit_train.jl.jl")  # new!
-include("implicit_train.jl.jl")  # Params etc, Zygote only
+include("explicit_train.jl")  # new!
+include("implicit_train.jl")  # Params etc, Zygote only
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...)  # can overload this to use e.g. Yota / Diffractor
 
-# using Requires  # Flux doesn't use this right now
-# @init @require Diffractor="9f5e2b26-1114-432f-b630-d3fe2085c51c" begin
-#   @eval function explicit_withgradient(f, args...)
-#     y, back = Diffractor.∂⃖¹(f, args...)
-#     _, grads... = back(Zygote.sensitivity(y))
-#     return (; value = y, gradient = grads)
-#   end
-# end
-
-#=
-
-using Diffractor
-function Flux.Train.explicit_withgradient(f, args...)
-  y, back = Diffractor.∂⃖¹(f, args...)
-  _, grads... = back(one(y))
-  return (; value = y, gradient = grads)
-end
-
-=#
-
 # ## Misc. related utilities
 
 """
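
The "two styles of gradient" named in the section comment above can be illustrated as follows. This is a hedged sketch; the toy model, data, and loss are invented for the example, and only `Zygote.gradient`, `Zygote.withgradient`, `Flux.params`, and `Flux.mse` are assumed:

```julia
using Flux, Zygote

m = Dense(2 => 1)                                # toy model
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)  # toy data

# Implicit style: gradients keyed by a Params collection of arrays (Zygote only).
ps = Flux.params(m)
gs = Zygote.gradient(() -> Flux.mse(m(x), y), ps)   # index as gs[m.weight]

# Explicit style: the model is an ordinary argument, and withgradient returns
# the loss value together with a gradient matching the model's structure.
val, grad = Zygote.withgradient(model -> Flux.mse(model(x), y), m)
```
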
@@ -107,94 +87,4 @@ function adjust!(opt::FluxState, eta::Real)
   return opt
 end
 
-"""
-    @epochs N body
-
-Run `body` expression `N` times. Mainly useful for quickly doing
-multiple epochs of training in a REPL.
-
-Functionally equivalent to this loop:
-```
-for _ in 1:N
-  body
-end
-```
-... but adds progress logging and `@info` messages,
-and returns the result of the last iteration.
-
-# Examples
-```jldoctest
-julia> Flux.@epochs 2 println("hello")
-[ Info: Epoch 1
-hello
-[ Info: Epoch 2
-hello
-```
-"""
-macro epochs(n, ex)
-  @gensym val
-  body = :(for i in 1:$(esc(n))
-    @info "Epoch $i"
-    $(esc(val)) = $(esc(ex))
-  end)
-  loop = Expr(:macrocall, Symbol("@progress"), __source__, body)
-  Expr(:block, :($(esc(val)) = nothing), loop, :($(esc(val))))
-  # TODO make this actualy return the value? Names aren't right.
-  #
-  #   $loop
-  #   # @progress for i in 1:$(esc(n))
-  #   #   @info "Epoch $i"
-  #   #   $(esc(val)) = $(esc(ex))
-  #   # end
-  #   $val  # DOESN"T WORK! Expr(:macrocall, ...) ?
-  # end
-end
-
-end
-
-
-#=
-
-using Flux, Random
-data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!;
-
-# This exact code works on Flux@0.13. There, train! returns nothing:
-model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
-opt2 = Flux.Adam()
-Flux.train!(Flux.params(model2), data, opt2) do x, y
-  Flux.mse(model2(x), y)
-end
-opt2  # contains an IdDict
-
-# This is the new "explicit" method of Train
-model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
-opt1 = Flux.Adam()
-Flux.train!(model1, data, opt1) do m, x, y
-  Flux.mse(m(x), y)
-end |> sum
-opt1  # contains state tree
-
-# This is new 3-arg train!, one step not an iteration over data:
-x1, y1 = data[1]
-Flux.train!(model1, opt1) do m
-  Flux.mse(m(x1), y1)
-end
-
-
-
-
-
-julia> using ProgressLogging
-julia> @macroexpand1 @loop N body
-begin
-  x = nothing
-  @progress for i in 1:N
-    @info "step $i"
-    x = body
-  end
-  x
-end
-
-
-
-=#
+end # module
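
Based on the commented-out demo in the removed lines above and on the `adjust!(opt::FluxState, eta::Real)` signature visible in the last hunk, the exported pieces fit together roughly as below. This is a hedged usage sketch with invented toy data, not documentation of the final API:

```julia
using Flux

data  = [(rand(Float32, 3, 2), rand(Float32, 1, 2)) for _ in 1:10]   # toy data
model = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt   = Flux.Adam()          # wrapped rule, per the FluxState docstring above

# New "explicit" train!: the model is passed to the loss as an argument.
Flux.train!(model, data, opt) do m, x, y
    Flux.mse(m(x), y)
end

# adjust! changes the learning rate of the wrapped rule and returns opt.
Flux.adjust!(opt, 1e-3)
```
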