@@ -8,18 +8,19 @@
 
 """
     Descent(η = 1f-1)
+    Descent(; eta)
 
 Classic gradient descent optimiser with learning rate `η`.
 For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
 """
 struct Descent{T} <: AbstractRule
   eta::T
 end
-Descent() = Descent(1f-1)
+Descent(; eta = 1f-1) = Descent(eta)
 
 init(o::Descent, x::AbstractArray) = nothing
 
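As a quick illustration (a sketch, not part of this commit): both spellings below should construct the same rule, and one step moves each parameter by `η` times its gradient. The `setup`/`update` calls are the package's usual API; the toy `model` and `grad` named tuples are invented for the example.

```julia
using Optimisers

Descent(0.1)         # positional, via the default struct constructor
Descent(eta = 0.1)   # keyword, via the new `Descent(; eta)` method

# One plain gradient-descent step, p .-= η .* dp, on a toy named-tuple model.
model = (w = [1.0, 2.0, 3.0],)
state = Optimisers.setup(Descent(eta = 0.1), model)
grad  = (w = [1.0, 1.0, 1.0],)
state, model = Optimisers.update(state, model, grad)
model.w   # ≈ [0.9, 1.9, 2.9]
```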
@@ -37,13 +38,14 @@
 
 """
     Momentum(η = 0.01, ρ = 0.9)
+    Momentum(; [eta, rho])
 
 Gradient descent optimizer with learning rate `η` and momentum `ρ`.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
 """
 @def struct Momentum <: AbstractRule
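For illustration only: the documented `Momentum(; [eta, rho])` form mirrors the positional constructor, and the rule is the conventional heavy-ball update, sketched by hand below under that assumption (toy values throughout).

```julia
using Optimisers

# Equivalent spellings, matching the documented `Momentum(; [eta, rho])` form:
Momentum(0.01, 0.9)
Momentum(eta = 0.01, rho = 0.9)

# Hand-written sketch of the heavy-ball update this rule is assumed to perform:
# a velocity accumulates discounted gradients, and the parameter moves against it.
η, ρ = 0.01, 0.9
vel = zeros(3)
p, dp = ones(3), fill(0.5, 3)
vel = ρ .* vel .+ η .* dp   # velocity update
p   = p .- vel              # parameter update
```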
@@ -89,6 +91,7 @@
 
 """
     RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
+    RMSProp(; [eta, rho, epsilon, centred])
 
 Optimizer using the
 [RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
@@ -99,11 +102,11 @@ generally don't need tuning.
 gradients by an estimate of their variance, instead of their second moment.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 - Keyword `centred` (or `centered`): Indicates whether to use the centred variant
   of the algorithm.
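The centred/uncentred distinction above is easiest to see written out. Below is a hand-derived sketch of one step using the standard RMSProp recurrences rather than the package's internal `apply!` code; all values are toy numbers.

```julia
η, ρ, ϵ = 0.001, 0.9, 1e-8
acc, avg = zeros(3), zeros(3)    # second-moment estimate, and mean (centred only)
p, dp = ones(3), fill(0.5, 3)

acc = ρ .* acc .+ (1 - ρ) .* dp .^ 2   # moving average of squared gradients
avg = ρ .* avg .+ (1 - ρ) .* dp        # moving average of gradients (centred only)

denom = sqrt.(acc .- avg .^ 2) .+ ϵ    # centred: approximates a standard deviation
# denom = sqrt.(acc) .+ ϵ              # uncentred: plain second-moment estimate
p = p .- η .* dp ./ denom
```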
@@ -115,10 +118,11 @@ struct RMSProp <: AbstractRule
   centred::Bool
 end
 
-function RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
+function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
   RMSProp(η, ρ, ϵ, centred | centered)
 end
+RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
 
 init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
 
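Because the new keyword method forwards its remaining keywords via `kw...` to the positional method, `centred` (or `centered`) passes through either spelling, and the learning-rate check applies on both paths:

```julia
using Optimisers

RMSProp(0.001, 0.9, 1e-8; centred = true)         # positional arguments
RMSProp(eta = 0.001, rho = 0.9, centred = true)   # keyword arguments, forwarded via kw...

# RMSProp(eta = -0.1)   # would throw DomainError: negative learning rate
```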
@@ -488,22 +492,27 @@ end
 
 """
     AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
+    AdamW(; [eta, beta, lambda, epsilon])
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
+Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref).
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
   OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
 
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
+  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
 
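Since the AdamW methods above are defined as an `OptimiserChain` of `Adam` and `WeightDecay`, the positional and keyword spellings can be checked against the chain written out by hand; the hyperparameter values below are arbitrary.

```julia
using Optimisers

# These three should construct equivalent rules:
AdamW(1e-3, (0.9, 0.999), 1e-4)                                # positional: η, β, λ
AdamW(eta = 1e-3, lambda = 1e-4)                               # new keyword form
OptimiserChain(Adam(1e-3, (0.9, 0.999)), WeightDecay(1e-4))    # spelled out by hand
```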