@@ -18,7 +18,7 @@ init(o::Descent, x::AbstractArray) = nothing
1818function apply! (o:: Descent , state, x, dx)
1919 η = convert (float (eltype (x)), o. eta)
2020
21- return state, @. . dx * η
21+ return state, @lazy dx * η # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
2222end
2323
2424"""
@@ -41,10 +41,10 @@ Momentum(η = 1f-2, ρ = 9f-1) = Momentum{typeof(η)}(η, ρ)
4141init (o:: Momentum , x:: AbstractArray ) = zero (x)
4242
4343function apply! (o:: Momentum , state, x, dx)
44- η, ρ, v = o. eta, o. rho, state
45- v′ = @. . v = ρ * v - η * dx
44+ η, ρ, mvel = o. eta, o. rho, state
45+ @. . mvel = ρ * mvel + η * dx # Macro @.. broadcasts into mvel if it can, else @. of rhs.
4646
47- return v′, @. . - v′
47+ return mvel, mvel
4848end
4949
5050"""
@@ -67,11 +67,12 @@ Nesterov(η = 1f-3, ρ = 9f-1) = Nesterov{typeof(η)}(η, ρ)
6767init (o:: Nesterov , x:: AbstractArray ) = zero (x)
6868
6969function apply! (o:: Nesterov , state, x, dx)
70- η, ρ, v = o. eta, o. rho, state
71- d = @. . ρ^ 2 * v - (1 + ρ) * η * dx
72- v′ = @. . v = ρ * v - η * dx
70+ η, ρ, vel = o. eta, o. rho, state
71+
72+ newdx = @. - ρ^ 2 * vel + (1 + ρ) * η * dx # Cannot be lazy as this needs the old velocity
73+ @. . vel = ρ * vel - η * dx
7374
74- return v′, @. . - d
75+ return vel, newdx
7576end
7677
7778"""
@@ -101,10 +102,11 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
101102
102103function apply! (o:: RMSProp , state, x, dx)
103104 η, ρ, ϵ, acc = o. eta, o. rho, o. epsilon, state
104- acc′ = @. . acc = ρ * acc + (1 - ρ) * dx^ 2
105- dx′ = @. . dx * (η / (sqrt (acc) + ϵ))
105+
106+ @. . acc = ρ * acc + (1 - ρ) * dx^ 2
107+ dx′ = @lazy dx * (η / (sqrt (acc) + ϵ))
106108
107- return acc′ , dx′
109+ return acc, dx′
108110end
109111
110112"""
@@ -129,15 +131,15 @@ ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(
129131
130132init (o:: ADAM , x:: AbstractArray ) = (zero (x), zero (x), o. beta)
131133
132- function apply! (o:: ADAM{T} , state, x, dx) where T
134+ function apply! (o:: ADAM , state, x, dx)
133135 η, β, ϵ = o. eta, o. beta, o. epsilon
134136 mt, vt, βt = state
135137
136- mt′ = @. . mt = β[1 ] * mt + (one (T) - β[1 ]) * dx
137- vt′ = @. . vt = β[2 ] * vt + (one (T) - β[2 ]) * dx ^ 2
138- dx′ = @. . mt / (one (T) - βt[1 ]) / (sqrt (vt / (one (T) - βt[2 ])) + ϵ) * η
138+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
139+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
140+ dx′ = @lazy mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η
139141
140- return (mt′ , vt′ , βt .* β), dx′
142+ return (mt, vt, βt .* β), dx′
141143end
142144
143145"""
@@ -168,17 +170,17 @@ function apply!(o::RADAM, state, x, dx)
168170
169171 mt, vt, βt, t = state
170172
171- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
172- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
173+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
174+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
173175 ρ = ρ∞ - 2 * t * βt[2 ] / (1 - βt[2 ])
174176 if ρ > 4
175177 r = sqrt ((ρ - 4 ) * (ρ - 2 ) * ρ∞/ ((ρ∞ - 4 ) * (ρ∞ - 2 ) * ρ))
176- dx′ = @. . mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η * r
178+ dx′ = @lazy mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ) * η * r
177179 else
178- dx′ = @. . mt / (1 - βt[1 ]) * η
180+ dx′ = @lazy mt / (1 - βt[1 ]) * η
179181 end
180182
181- return (mt′ , vt′ , βt .* β, t + 1 ), dx′
183+ return (mt, vt, βt .* β, t + 1 ), dx′
182184end
183185
184186"""
@@ -205,14 +207,13 @@ init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta)
205207
206208function apply! (o:: AdaMax , state, x, dx)
207209 η, β, ϵ = o. eta, o. beta, o. epsilon
208-
209210 mt, ut, βt = state
210211
211- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
212- ut′ = @. . ut = max (β[2 ] * ut, abs (dx))
213- dx′ = @. . (η/ (1 - βt[1 ])) * mt/ (ut + ϵ)
212+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
213+ @. . ut = max (β[2 ] * ut, abs (dx))
214+ dx′ = @lazy (η/ (1 - βt[1 ])) * mt/ (ut + ϵ)
214215
215- return (mt′ , ut′ , βt .* β), dx′
216+ return (mt, ut, βt .* β), dx′
216217end
217218
218219"""
@@ -240,16 +241,15 @@ init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
240241
241242function apply! (o:: OADAM , state, x, dx)
242243 η, β, ϵ = o. eta, o. beta, o. epsilon
244+ mt, vt, βt, term = state
243245
244- mt, vt, βt, dx_ = state
246+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
247+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
248+ prev = copy (term)
249+ @. . term = η * mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ)
250+ dx′ = @lazy 2 * term - prev
245251
246- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
247- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
248- dx = @. . - dx_
249- dx_′ = @. . dx_ = η * mt / (1 - βt[1 ]) / (sqrt (vt / (1 - βt[2 ])) + ϵ)
250- dx′ = @. . dx + 2 * dx_
251-
252- return (mt′, vt′, βt .* β, dx_′), dx′
252+ return (mt, vt, βt .* β, term), dx′
253253end
254254
255255"""
@@ -271,16 +271,16 @@ struct ADAGrad{T}
271271end
272272ADAGrad (η = 1f-1 , ϵ = eps (typeof (η))) = ADAGrad {typeof(η)} (η, ϵ)
273273
274- init (o:: ADAGrad , x:: AbstractArray ) = fill! ( similar (x), o. epsilon)
274+ init (o:: ADAGrad , x:: AbstractArray ) = onevalue ( o. epsilon, x )
275275
276276function apply! (o:: ADAGrad , state, x, dx)
277277 η, ϵ = o. eta, o. epsilon
278278 acc = state
279279
280- acc′ = @. . acc = acc + dx^ 2
281- dx′ = @. . dx * η / (sqrt (acc) + ϵ)
280+ @. . acc = acc + dx^ 2
281+ dx′ = @lazy dx * η / (sqrt (acc) + ϵ)
282282
283- return acc′ , dx′
283+ return acc, dx′
284284end
285285
286286"""
@@ -307,13 +307,12 @@ function apply!(o::ADADelta, state, x, dx)
307307 ρ, ϵ = o. rho, o. epsilon
308308 acc, Δacc = state
309309
310- acc′ = @. . acc = ρ * acc + (1 - ρ) * dx^ 2
311- # DON'T remove epsilon from numerator
312- # or even out of the square roots
313- dx′ = @. . dx * sqrt (Δacc + ϵ) / sqrt (acc + ϵ)
314- Δacc′ = @. . Δacc = ρ * Δacc + (1 - ρ) * dx^ 2
310+ @. . acc = ρ * acc + (1 - ρ) * dx^ 2
311+ # DON'T remove epsilon from numerator or even out of the square roots!
312+ dx′ = @. dx * sqrt (Δacc + ϵ) / sqrt (acc + ϵ) # Cannot be lazy as this needs the old Δacc
313+ @. . Δacc = ρ * Δacc + (1 - ρ) * dx′^ 2
315314
316- return (acc′ , Δacc′ ), dx′
315+ return (acc, Δacc), dx′
317316end
318317
319318"""
@@ -338,19 +337,18 @@ end
338337AMSGrad (η = 1f-3 , β = (9f-1 , 9.99f-1 ), ϵ = eps (typeof (η))) = AMSGrad {typeof(η)} (η, β, ϵ)
339338
340339init (o:: AMSGrad , x:: AbstractArray ) =
341- (fill! ( similar (x), o. epsilon), fill! ( similar ( x), o. epsilon), fill! ( similar ( x), o. epsilon))
340+ (onevalue ( o. epsilon, x), onevalue ( o. epsilon, x), onevalue ( o. epsilon, x ))
342341
343342function apply! (o:: AMSGrad , state, x, dx)
344343 η, β, ϵ = o. eta, o. beta, o. epsilon
345-
346344 mt, vt, v̂t = state
347345
348- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
349- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
350- v̂t′ = @. . v̂t = max (v̂t, vt)
351- dx′ = @. . η * mt / (sqrt (v̂t) + ϵ)
346+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
347+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx ^ 2
348+ @. . v̂t = max (v̂t, vt)
349+ dx′ = @lazy η * mt / (sqrt (v̂t) + ϵ)
352350
353- return (mt′ , vt′ , v̂t′ ), dx′
351+ return (mt, vt, v̂t), dx′
354352end
355353
356354"""
@@ -381,12 +379,12 @@ function apply!(o::NADAM, state, x, dx)
381379
382380 mt, vt, βt = state
383381
384- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
385- vt′ = @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
386- dx′ = @. . (β[1 ] * mt / (1 - β[1 ] * βt[1 ]) + (1 - β[1 ]) * dx / (1 - βt[1 ])) /
382+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
383+ @. . vt = β[2 ] * vt + (1 - β[2 ]) * dx^ 2
384+ dx′ = @lazy (β[1 ] * mt / (1 - β[1 ] * βt[1 ]) + (1 - β[1 ]) * dx / (1 - βt[1 ])) /
387385 (sqrt (vt * β[2 ] / (1 - βt[2 ])) + ϵ) * η
388386
389- return (mt′ , vt′ , βt .* β), dx′
387+ return (mt, vt, βt .* β), dx′
390388end
391389
392390"""
@@ -405,7 +403,7 @@ weight decay regularization.
405403 (no need to change default)
406404"""
407405ADAMW (η = 1f-3 , β = (9f-1 , 9.99f-1 ), γ = 0 , ϵ = eps (typeof (η))) =
408- OptimiserChain (ADAM {typeof(η)} (η, β, ϵ), WeightDecay (γ))
406+ OptimiserChain (ADAM {typeof(η)} (η, β, ϵ), WeightDecay {typeof(η)} (γ))
409407
410408"""
411409 AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
@@ -434,11 +432,11 @@ function apply!(o::AdaBelief, state, x, dx)
434432 η, β, ϵ = o. eta, o. beta, o. epsilon
435433 mt, st = state
436434
437- mt′ = @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
438- st′ = @. . st = β[2 ] * st + (1 - β[2 ]) * (dx - mt)^ 2
439- dx′ = @. . η * mt / (sqrt (st) + ϵ)
435+ @. . mt = β[1 ] * mt + (1 - β[1 ]) * dx
436+ @. . st = β[2 ] * st + (1 - β[2 ]) * (dx - mt)^ 2
437+ dx′ = @lazy η * mt / (sqrt (st) + ϵ)
440438
441- return (mt′ , st′ ), dx′
439+ return (mt, st), dx′
442440end
443441
444442"""
@@ -457,7 +455,7 @@ WeightDecay() = WeightDecay(5f-4)
457455init (o:: WeightDecay , x:: AbstractArray ) = nothing
458456
459457function apply! (o:: WeightDecay , state, x, dx)
460- dx′ = @. . dx + o. wd * x
458+ dx′ = @lazy dx + o. wd * x
461459
462460 return state, dx′
463461end
@@ -478,7 +476,7 @@ init(o::ClipGrad, x::AbstractArray) = nothing
478476
479477function apply! (o:: ClipGrad , state, x, dx)
480478 δ = convert (float (eltype (x)), o. delta)
481- dx′ = @. . clamp (dx, - δ, δ)
479+ dx′ = @lazy clamp (dx, - δ, δ)
482480
483481 return state, dx′
484482end
@@ -510,7 +508,7 @@ function apply!(o::ClipNorm, state, x, dx)
510508 end
511509 λ = min (o. omega / nrm, 1 )
512510
513- return state, @. . dx * λ
511+ return state, @lazy dx * λ
514512end
515513
516514"""
0 commit comments