-
Notifications
You must be signed in to change notification settings - Fork 10
Added group norm L0 and shifted group norm L0 #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
c3e9b5d
95a5c6f
198ff38
76ed644
cec91e3
7a17fd7
48ce947
b6075f7
36fa78f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,68 @@ | ||||||||||||||||||||||||||
| # Group L2 norm (times a constant) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| export GroupNormL0 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @doc raw""" | ||||||||||||||||||||||||||
| GroupNormL0(λ = 1, idx = [:]) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns the group ``\ell_0``-norm operator | ||||||||||||||||||||||||||
| ```math | ||||||||||||||||||||||||||
| f(x) = \sum_i \lambda_i \| \|x_{[i]}\|_2 \|_0 | ||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||
| for groups ``x_{[i]}`` and nonnegative weights ``\lambda_i``. | ||||||||||||||||||||||||||
| This assumes that the groups ``x_{[i]}`` are non-overlapping | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} | ||||||||||||||||||||||||||
| lambda::RR | ||||||||||||||||||||||||||
| idx::I | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} | ||||||||||||||||||||||||||
| if any(lambda .< 0) | ||||||||||||||||||||||||||
| error("weights λ must be nonnegative") | ||||||||||||||||||||||||||
| elseif length(lambda) != length(idx) | ||||||||||||||||||||||||||
| error("number of weights and groups must be the same") | ||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||
| new{R, RR, I}(lambda, idx) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
AHsu98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| GroupNormL0(lambda::AbstractVector{R} = [one(R)], idx::I = [:]) where {R <: Real, I} = | ||||||||||||||||||||||||||
| GroupNormL0{R, typeof(lambda), I}(lambda, idx) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function (f::GroupNormL0)(x::AbstractArray{R}) where {R <: Real} | ||||||||||||||||||||||||||
| sum_c = R(0) | ||||||||||||||||||||||||||
| for (idx, λ) ∈ zip(f.idx, f.lambda) | ||||||||||||||||||||||||||
| y = norm(x[idx]) | ||||||||||||||||||||||||||
| if y>0 | ||||||||||||||||||||||||||
AHsu98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
| sum_c += λ | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| return sum_c | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| function prox!( | ||||||||||||||||||||||||||
| y::AbstractArray{R}, | ||||||||||||||||||||||||||
| f::GroupNormL0{R, RR, I}, | ||||||||||||||||||||||||||
| x::AbstractArray{R}, | ||||||||||||||||||||||||||
| γ::R = R(1), | ||||||||||||||||||||||||||
| ) where {R <: Real, RR <: AbstractVector{R}, I} | ||||||||||||||||||||||||||
|
Comment on lines
+40
to
+45
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| ysum = R(0) | ||||||||||||||||||||||||||
| for (idx, λ) ∈ zip(f.idx, f.lambda) | ||||||||||||||||||||||||||
| yt = norm(x[idx])^2 | ||||||||||||||||||||||||||
| if yt !=0 | ||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| ysum += λ | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| if yt <= 2 * γ * λ | ||||||||||||||||||||||||||
| y[idx] .= 0 | ||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||
| y[idx] .= x[idx] | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| return ysum | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| fun_name(f::GroupNormL0) = "Group L₀-norm" | ||||||||||||||||||||||||||
| fun_dom(f::GroupNormL0) = "AbstractArray{Float64}, AbstractArray{Complex}" | ||||||||||||||||||||||||||
| fun_expr(f::GroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀" | ||||||||||||||||||||||||||
| fun_params(f::GroupNormL0) = "λ = $(f.lambda), g = $(f.g)" | ||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| export ShiftedGroupNormL0 | ||
|
|
||
| mutable struct ShiftedGroupNormL0{ | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use V0, V1, V2, V3 here. |
||
| I, | ||
| V0 <: AbstractVector{R}, | ||
| V1 <: AbstractVector{R}, | ||
| V2 <: AbstractVector{R}, | ||
| } <: ShiftedProximableFunction | ||
| h::GroupNormL0{R, RR, I} | ||
| xk::V0 | ||
| sj::V1 | ||
| sol::V2 | ||
| shifted_twice::Bool | ||
| xsy::V2 | ||
|
|
||
| function ShiftedGroupNormL0( | ||
| h::GroupNormL0{R, RR, I}, | ||
| xk::AbstractVector{R}, | ||
| sj::AbstractVector{R}, | ||
| shifted_twice::Bool, | ||
| ) where {R <: Real, RR <: AbstractVector{R}, I} | ||
| sol = similar(sj) | ||
| xsy = similar(sj) | ||
| new{R, RR, I, typeof(xk), typeof(sj), typeof(sol)}(h, xk, sj, sol, shifted_twice, xsy) | ||
| end | ||
| end | ||
|
|
||
| shifted( | ||
| h::GroupNormL0{R, RR, I}, | ||
| xk::AbstractVector{R}, | ||
| ) where {R <: Real, RR <: AbstractVector{R}, I} = ShiftedGroupNormL0(h, xk, zero(xk), false) | ||
| shifted(h::NormL2{R}, xk::AbstractVector{R}) where {R <: Real} = | ||
| ShiftedGroupNormL0(GroupNormL0([h.lambda]), xk, zero(xk), false) | ||
| shifted( | ||
| ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, | ||
| sj::AbstractVector{R}, | ||
| ) where { | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, | ||
| I, | ||
| V0 <: AbstractVector{R}, | ||
| V1 <: AbstractVector{R}, | ||
| V2 <: AbstractVector{R}, | ||
| } = ShiftedGroupNormL0(ψ.h, ψ.xk, sj, true) | ||
|
|
||
| fun_name(ψ::ShiftedGroupNormL0) = "shifted x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀ function" | ||
| fun_expr(ψ::ShiftedGroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xk + sj + t‖₂" | ||
| fun_params(ψ::ShiftedGroupNormL0) = "xk = $(ψ.xk)\n" * " "^14 * "sj = $(ψ.sj)\n" * " "^14 | ||
|
|
||
| function prox!( | ||
| y::AbstractVector{R}, | ||
| ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, | ||
| q::AbstractVector{R}, | ||
| σ::R, | ||
| ) where { | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, | ||
| I, | ||
| V0 <: AbstractVector{R}, | ||
| V1 <: AbstractVector{R}, | ||
| V2 <: AbstractVector{R}, | ||
| } | ||
| ψ.sol .= q + ψ.xk + ψ.sj | ||
AHsu98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) | ||
| snorm = norm(ψ.sol[idx])^2 | ||
| if snorm <= 2 * γ * λ | ||
| y[idx] .= 0 | ||
| else | ||
| y[idx] .= ψ.sol[idx] | ||
| end | ||
| end | ||
| y .-= (ψ.xk + ψ.sj) | ||
AHsu98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return y | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.