@@ -487,7 +487,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8)
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
@@ -497,12 +497,12 @@ weight decay regularization.
   the weights.
 - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Weight decay (`γ`): Decay applied to weights during optimisation.
+- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
-  OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
+AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
 
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
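For orientation, a minimal usage sketch of the renamed constructor — the array `x` and all step values here are illustrative, not from the patch:

```julia
using Optimisers

x = randn(Float32, 10)
opt = AdamW(0.001, (0.9, 0.999), 0.01)   # λ = 0.01, i.e. Adam(0.001, (0.9, 0.999)) + WeightDecay(0.01)
state = Optimisers.setup(opt, x)
state, x = Optimisers.update(state, x, ones(Float32, 10))  # one step with a dummy gradient
```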
@@ -538,35 +538,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    WeightDecay(γ = 5e-4)
+    WeightDecay(λ = 5e-4)
 
-Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
-subtracted from `x`.
+Implements ``L_2`` regularisation, also known as ridge regression,
+when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
 
-Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
-This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
+It does this by adding `λ .* x` to the gradient. This is equivalent to adding
+`λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.
+
+See also [`SignDecay`](@ref) for ``L_1`` regularisation.
 
 # Parameters
-- Weight decay (`γ`): Decay applied to weights during optimisation.
+- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
 """
 @def struct WeightDecay <: AbstractRule
-  gamma = 5e-4
+  lambda = 5e-4
 end
 
 init(o::WeightDecay, x::AbstractArray) = nothing
 
 function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
-  λ = T(o.gamma)
-  dx′ = @lazy dx + γ * x
+  λ = T(o.lambda)
+  dx′ = @lazy dx + λ * x
 
   return state, dx′
 end
 
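The docstring's loss-equivalence claim is easy to check numerically; a quick sketch, assuming Zygote (not a dependency of this file) for the gradient:

```julia
using Zygote  # assumption: used only for this check

x, λ = [1.0, -2.0, 3.0], 0.05
penalty_grad = gradient(x -> λ/2 * sum(abs2, x), x)[1]
penalty_grad ≈ λ .* x    # true: WeightDecay adds exactly this term to the gradient
```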
+function adjust(r::WeightDecay; gamma = nothing, kw...)
+  if isnothing(gamma)
+    return _adjust(r, NamedTuple(kw))
+  else
+    Base.depwarn("The strength of WeightDecay is now field :lambda, not :gamma", :adjust, force = true)
+    nt = (; lambda = gamma, NamedTuple(kw)...)
+    return _adjust(r, nt)
+  end
+end
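The shim above keeps `adjust` working under the old field name. Illustrative calls (values arbitrary):

```julia
rule = WeightDecay()          # lambda = 5e-4
adjust(rule; lambda = 1e-2)   # new field name, handled by _adjust directly
adjust(rule; gamma = 1e-2)    # old name: emits the depwarn, then sets lambda = 1e-2
```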
+
+"""
+    SignDecay(λ = 1e-3)
+
+Implements ``L_1`` regularisation, also known as LASSO regression,
+when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
+
+It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
+`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.
+
+See also [`WeightDecay`](@ref) for ``L_2`` regularisation.
+They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
+is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.
+
+# Parameters
+- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
+"""
+@def struct SignDecay <: AbstractRule
+  lambda = 1e-3
+end
+
+init(o::SignDecay, x::AbstractArray) = nothing
+
+function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
+  λ = T(o.lambda)
+  dx′ = @lazy dx + λ * sign(x)
+
+  return state, dx′
+end
+
+
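A short sketch of the elastic-net style combination from the docstring, using its illustrative penalty strengths (assumes `Optimisers` is loaded as above):

```julia
x = randn(Float32, 10)
opt = OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())
state = Optimisers.setup(opt, x)
state, x = Optimisers.update(state, x, ones(Float32, 10))  # L1 and L2 penalties applied before Adam
```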
565607"""
566608 ClipGrad(δ = 10)
567609
568610Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.
569611
612+ Typically composed with other rules using [`OptimiserChain`](@ref).
613+
570614See also [`ClipNorm`](@ref).
571615"""
572616@def struct ClipGrad <: AbstractRule
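Following the docstring's advice, for instance (threshold arbitrary):

```julia
opt = OptimiserChain(ClipGrad(1.0), Adam())  # clamp each gradient component to [-1, 1] before Adam
```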
@@ -591,6 +635,8 @@ to stay at this threshold (unless `p==0`).
 Throws an error if the norm is infinite or `NaN`,
 which you can turn off with `throw = false`.
 
+Typically composed with other rules using [`OptimiserChain`](@ref).
+
 See also [`ClipGrad`](@ref).
 """
 struct ClipNorm <: AbstractRule
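And the corresponding norm-based version, again with an arbitrary threshold:

```julia
opt = OptimiserChain(ClipNorm(1.0), Adam())  # rescale any gradient whose 2-norm exceeds 1
```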