diff --git a/src/ShiftedProximalOperators.jl b/src/ShiftedProximalOperators.jl index c400c958..527c13a6 100644 --- a/src/ShiftedProximalOperators.jl +++ b/src/ShiftedProximalOperators.jl @@ -30,6 +30,7 @@ include("utils.jl") include("psvd.jl") include("rootNormLhalf.jl") +include("groupNormL0.jl") include("groupNormL2.jl") include("Rank.jl") include("cappedl1.jl") @@ -39,6 +40,7 @@ include("shiftedNormL0.jl") include("shiftedNormL0Box.jl") include("shiftedRootNormLhalf.jl") include("shiftedNormL1.jl") +include("shiftedGroupNormL0.jl") include("shiftedGroupNormL2.jl") include("shiftedNormL1B2.jl") diff --git a/src/groupNormL0.jl b/src/groupNormL0.jl new file mode 100644 index 00000000..00f1342c --- /dev/null +++ b/src/groupNormL0.jl @@ -0,0 +1,64 @@ +# Group L0 norm (times a constant) + +export GroupNormL0 + +@doc raw""" + GroupNormL0(λ = 1, idx = [:]) + +Returns the group ``\ell_0``-norm operator +```math +f(x) = \sum_i \lambda_i \| \|x_{[i]}\|_2 \|_0 +``` +for groups ``x_{[i]}`` and nonnegative weights ``\lambda_i``. +This assumes that the groups ``x_{[i]}`` are non-overlapping +""" +struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} + lambda::RR + idx::I + + function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} + any(lambda .< 0) && error("weights λ must be nonnegative") + length(lambda) != length(idx) && error("number of weights and groups must be the same") + new{R, RR, I}(lambda, idx) + end +end + +GroupNormL0(lambda::AbstractVector{R} = [one(R)], idx::I = [:]) where {R <: Real, I} = + GroupNormL0{R, typeof(lambda), I}(lambda, idx) + +function (f::GroupNormL0)(x::AbstractArray{R}) where {R <: Real} + sum_c = R(0) + for (idx, λ) ∈ zip(f.idx, f.lambda) + y = norm(x[idx]) + if y > 0 + sum_c += λ + end + end + return sum_c +end + +function prox!( + y::AbstractArray{R}, + f::GroupNormL0{R, RR, I}, + x::AbstractArray{R}, + γ::R = R(1), +) where {R <: Real, RR <: AbstractVector{R}, I} + ysum = R(0) + for (idx, λ) ∈ zip(f.idx, f.lambda) + yt = norm(x[idx])^2 + if yt !=0 + ysum += λ + end + if yt <= 2 * γ * λ + y[idx] .= 0 + else + y[idx] .= x[idx] + end + end + return ysum +end + +fun_name(f::GroupNormL0) = "Group L₀-norm" +fun_dom(f::GroupNormL0) = "AbstractArray{Float64}, AbstractArray{Complex}" +fun_expr(f::GroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀" +fun_params(f::GroupNormL0) = "λ = $(f.lambda), g = $(f.g)" diff --git a/src/groupNormL2.jl b/src/groupNormL2.jl index bf000337..45e8106a 100644 --- a/src/groupNormL2.jl +++ b/src/groupNormL2.jl @@ -17,13 +17,9 @@ struct GroupNormL2{R <: Real, RR <: AbstractVector{R}, I} idx::I function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} - if any(lambda .< 0) - error("weights λ must be nonnegative") - elseif length(lambda) != length(idx) - error("number of weights and groups must be the same") - else - new{R, RR, I}(lambda, idx) - end + any(lambda .< 0) && error("weights λ must be nonnegative") + length(lambda) != length(idx) && error("number of weights and groups must be the same") + new{R, RR, I}(lambda, idx) end end diff --git a/src/shiftedGroupNormL0.jl b/src/shiftedGroupNormL0.jl new file mode 100644 index 00000000..3b12808b --- /dev/null +++ b/src/shiftedGroupNormL0.jl @@ -0,0 +1,77 @@ +export ShiftedGroupNormL0 + +mutable struct ShiftedGroupNormL0{ + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, +} <: ShiftedProximableFunction + h::GroupNormL0{R, RR, I} + xk::V0 + sj::V1 + sol::V2 + shifted_twice::Bool + xsy::V2 + + function ShiftedGroupNormL0( + h::GroupNormL0{R, RR, I}, + xk::AbstractVector{R}, + sj::AbstractVector{R}, + shifted_twice::Bool, + ) where {R <: Real, RR <: AbstractVector{R}, I} + sol = similar(sj) + xsy = similar(sj) + new{R, RR, I, typeof(xk), typeof(sj), typeof(sol)}(h, xk, sj, sol, shifted_twice, xsy) + end +end + +shifted( + h::GroupNormL0{R, RR, I}, + xk::AbstractVector{R}, +) where {R <: Real, RR <: AbstractVector{R}, I} = ShiftedGroupNormL0(h, xk, zero(xk), false) +shifted(h::NormL2{R}, xk::AbstractVector{R}) where {R <: Real} = + ShiftedGroupNormL0(GroupNormL0([h.lambda]), xk, zero(xk), false) +shifted( + ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, + sj::AbstractVector{R}, +) where { + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, +} = ShiftedGroupNormL0(ψ.h, ψ.xk, sj, true) + +fun_name(ψ::ShiftedGroupNormL0) = "shifted x ↦ Σᵢ λᵢ ‖ ‖xᵢ‖₂ ‖₀ function" +fun_expr(ψ::ShiftedGroupNormL0) = "x ↦ Σᵢ λᵢ ‖ ‖xk + sj + t‖₂" +fun_params(ψ::ShiftedGroupNormL0) = "xk = $(ψ.xk)\n" * " "^14 * "sj = $(ψ.sj)\n" * " "^14 + +function prox!( + y::AbstractVector{R}, + ψ::ShiftedGroupNormL0{R, RR, I, V0, V1, V2}, + q::AbstractVector{R}, + σ::R, +) where { + R <: Real, + RR <: AbstractVector{R}, + I, + V0 <: AbstractVector{R}, + V1 <: AbstractVector{R}, + V2 <: AbstractVector{R}, +} + ψ.sol .= q .+ ψ.xk .+ ψ.sj + + for (idx, λ) ∈ zip(ψ.h.idx, ψ.h.lambda) + snorm = norm(ψ.sol[idx])^2 + if snorm <= 2 * σ * λ + y[idx] .= 0 + else + y[idx] .= ψ.sol[idx] + end + end + y .-= (ψ.xk .+ ψ.sj) + return y +end