-
Notifications
You must be signed in to change notification settings - Fork 10
Added group norm L0 and shifted group norm L0 #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
c3e9b5d
95a5c6f
198ff38
76ed644
cec91e3
7a17fd7
48ce947
b6075f7
36fa78f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,64 @@ | ||||||||||||||||||||||||||
| # Group L0 norm (times a constant) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| export GroupNormL0 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @doc raw""" | ||||||||||||||||||||||||||
| GroupNormL0(λ = 1, idx = [:]) | ||||||||||||||||||||||||||
| Returns the group ``\ell_0``-norm operator | ||||||||||||||||||||||||||
| ```math | ||||||||||||||||||||||||||
| f(x) = \sum_i \lambda_i \| \|x_{[i]}\|_2 \|_0 | ||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||
| for groups ``x_{[i]}`` and nonnegative weights ``\lambda_i``. | ||||||||||||||||||||||||||
| This assumes that the groups ``x_{[i]}`` are non-overlapping | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} | ||||||||||||||||||||||||||
| lambda::RR | ||||||||||||||||||||||||||
| idx::I | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} | ||||||||||||||||||||||||||
| any(lambda .< 0) && error("weights λ must be nonnegative") | ||||||||||||||||||||||||||
| length(lambda) != length(idx) && error("number of weights and groups must be the same") | ||||||||||||||||||||||||||
| new{R, RR, I}(lambda, idx) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| GroupNormL0(lambda::AbstractVector{R} = [one(R)], idx::I = [:]) where {R <: Real, I} = | ||||||||||||||||||||||||||
| GroupNormL0{R, typeof(lambda), I}(lambda, idx) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function (f::GroupNormL0)(x::AbstractArray{R}) where {R <: Real} | ||||||||||||||||||||||||||
| sum_c = R(0) | ||||||||||||||||||||||||||
| for (idx, λ) ∈ zip(f.idx, f.lambda) | ||||||||||||||||||||||||||
| y = norm(x[idx]) | ||||||||||||||||||||||||||
| if y > 0 | ||||||||||||||||||||||||||
| sum_c += λ | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| return sum_c | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function prox!( | ||||||||||||||||||||||||||
| y::AbstractArray{R}, | ||||||||||||||||||||||||||
| f::GroupNormL0{R, RR, I}, | ||||||||||||||||||||||||||
| x::AbstractArray{R}, | ||||||||||||||||||||||||||
| γ::R = R(1), | ||||||||||||||||||||||||||
| ) where {R <: Real, RR <: AbstractVector{R}, I} | ||||||||||||||||||||||||||
|
Comment on lines
+40
to
+45
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| ysum = R(0) | ||||||||||||||||||||||||||
| for (idx, λ) ∈ zip(f.idx, f.lambda) | ||||||||||||||||||||||||||
| yt = norm(x[idx])^2 | ||||||||||||||||||||||||||
| if yt !=0 | ||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| ysum += λ | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| if yt <= 2 * γ * λ | ||||||||||||||||||||||||||
| y[idx] .= 0 | ||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||
| y[idx] .= x[idx] | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| return ysum | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| fun_name(f::GroupNormL0) = "Group L₀-norm" | ||||||||||||||||||||||||||
| fun_dom(f::GroupNormL0) = "AbstractArray{Float64}, AbstractArray{Complex}" | ||||||||||||||||||||||||||
| fun_expr(f::GroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀" | ||||||||||||||||||||||||||
| fun_params(f::GroupNormL0) = "λ = $(f.lambda), g = $(f.g)" | ||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -17,13 +17,9 @@ struct GroupNormL2{R <: Real, RR <: AbstractVector{R}, I} | |||||
| idx::I | ||||||
|
|
||||||
| function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| if any(lambda .< 0) | ||||||
| error("weights λ must be nonnegative") | ||||||
| elseif length(lambda) != length(idx) | ||||||
| error("number of weights and groups must be the same") | ||||||
| else | ||||||
| new{R, RR, I}(lambda, idx) | ||||||
| end | ||||||
| any(lambda .< 0) && error("weights λ must be nonnegative") | ||||||
| length(lambda) != length(idx) && error("number of weights and groups must be the same") | ||||||
| new{R, RR, I}(lambda, idx) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There may be others in this file. It’s just easier to remember that V stands for “vector”. |
||||||
| end | ||||||
| end | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| export ShiftedGroupNormL0 | ||
|
|
||
| mutable struct ShiftedGroupNormL0{ | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use V0, V1, V2, V3 here. |
||
| I, | ||
| V0 <: AbstractVector{R}, | ||
| V1 <: AbstractVector{R}, | ||
| V2 <: AbstractVector{R}, | ||
| } <: ShiftedProximableFunction | ||
| h::GroupNormL0{R, RR, I} | ||
| xk::V0 | ||
| sj::V1 | ||
| sol::V2 | ||
| shifted_twice::Bool | ||
| xsy::V2 | ||
|
|
||
| function ShiftedGroupNormL0( | ||
| h::GroupNormL0{R, RR, I}, | ||
| xk::AbstractVector{R}, | ||
| sj::AbstractVector{R}, | ||
| shifted_twice::Bool, | ||
| ) where {R <: Real, RR <: AbstractVector{R}, I} | ||
| sol = similar(sj) | ||
| xsy = similar(sj) | ||
| new{R, RR, I, typeof(xk), typeof(sj), typeof(sol)}(h, xk, sj, sol, shifted_twice, xsy) | ||
| end | ||
| end | ||
|
|
||
| shifted( | ||
| h::GroupNormL0{R, RR, I}, | ||
| xk::AbstractVector{R}, | ||
| ) where {R <: Real, RR <: AbstractVector{R}, I} = ShiftedGroupNormL0(h, xk, zero(xk), false) | ||
| shifted(h::NormL2{R}, xk::AbstractVector{R}) where {R <: Real} = | ||
| ShiftedGroupNormL0(GroupNormL0([h.lambda]), xk, zero(xk), false) | ||
| shifted( | ||
| ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, | ||
| sj::AbstractVector{R}, | ||
| ) where { | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, | ||
| I, | ||
| V0 <: AbstractVector{R}, | ||
| V1 <: AbstractVector{R}, | ||
| V2 <: AbstractVector{R}, | ||
| } = ShiftedGroupNormL0(ψ.h, ψ.xk, sj, true) | ||
|
|
||
| fun_name(ψ::ShiftedGroupNormL0) = "shifted x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀ function" | ||
| fun_expr(ψ::ShiftedGroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xk + sj + t‖₂" | ||
| fun_params(ψ::ShiftedGroupNormL0) = "xk = $(ψ.xk)\n" * " "^14 * "sj = $(ψ.sj)\n" * " "^14 | ||
|
|
||
| function prox!( | ||
| y::AbstractVector{R}, | ||
| ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, | ||
| q::AbstractVector{R}, | ||
| σ::R, | ||
| ) where { | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, | ||
| I, | ||
| V0 <: AbstractVector{R}, | ||
| V1 <: AbstractVector{R}, | ||
| V2 <: AbstractVector{R}, | ||
| } | ||
| ψ.sol .= q .+ ψ.xk .+ ψ.sj | ||
|
|
||
| for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) | ||
| snorm = norm(ψ.sol[idx])^2 | ||
| if snorm <= 2 * σ * λ | ||
| y[idx] .= 0 | ||
| else | ||
| y[idx] .= ψ.sol[idx] | ||
| end | ||
| end | ||
| y .-= (ψ.xk .+ ψ.sj) | ||
| return y | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.