From 394725dbb60261445021e69c9e0d786fe5f33fc6 Mon Sep 17 00:00:00 2001 From: MohamedLaghdafHABIBOULLAH Date: Tue, 17 Sep 2024 01:55:03 -0400 Subject: [PATCH 1/2] generalization of iprox for L0 and L1 norm if D is not positive --- src/shiftedNormL0.jl | 18 +++++++++++------- src/shiftedNormL1.jl | 8 ++++++-- test/partial_prox.jl | 1 - 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/shiftedNormL0.jl b/src/shiftedNormL0.jl index 38be69d2..f817ac9b 100644 --- a/src/shiftedNormL0.jl +++ b/src/shiftedNormL0.jl @@ -67,13 +67,17 @@ function iprox!( λ = ψ.h.lambda for i ∈ eachindex(y) di = d[i] - @assert di > 0 - ci = sqrt(2 * λ * di) - xps = ψ.xk[i] + ψ.sj[i] - if abs(di * xps - g[i]) ≤ ci - y[i] = -xps - else - y[i] = -g[i] / di + if d[i] ≤ 0 + y[i] = - 1/eps(R) + continue + else + ci = sqrt(2 * λ * di) + xps = ψ.xk[i] + ψ.sj[i] + if abs(di * xps - g[i]) ≤ ci + y[i] = -xps + else + y[i] = -g[i] / di + end end end return y diff --git a/src/shiftedNormL1.jl b/src/shiftedNormL1.jl index af89b210..e2d23fc3 100644 --- a/src/shiftedNormL1.jl +++ b/src/shiftedNormL1.jl @@ -67,8 +67,12 @@ function iprox!( @. y = -ψ.xk - ψ.sj for i ∈ eachindex(y) - @assert d[i] > 0 - y[i] = min(max(y[i], -g[i] / d[i] - λ / d[i]), -g[i] / d[i] + λ / d[i]) + if d[i] < 0 + y[i] = - 1/eps(R) + continue + else + y[i] = min(max(y[i], -g[i] / d[i] - λ / d[i]), -g[i] / d[i] + λ / d[i]) + end end return y diff --git a/test/partial_prox.jl b/test/partial_prox.jl index e7e5b2ff..bef783d7 100644 --- a/test/partial_prox.jl +++ b/test/partial_prox.jl @@ -58,7 +58,6 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf) # tests iprox without bounds if op == :NormL0 || op == :NormL1 ψ = shifted(h, x) - @test_throws AssertionError iprox(ψ, q, zeros(n)) for d ∈ [ones(n), 2 * ones(n)] y = iprox(ψ, q, d) σ = d[1] From b6711dfa410ce796a2a95cb21e645d100cd9d621 Mon Sep 17 00:00:00 2001 From: MohamedLaghdafHABIBOULLAH Date: Sat, 22 Feb 2025 07:32:54 -0500 Subject: [PATCH 2/2] Add the special cases for di<0 and di=0, along with corresponding test --- src/shiftedNormL0.jl | 16 +++++++++++----- src/shiftedNormL1.jl | 13 +++++++++---- test/partial_prox.jl | 17 ++++++++++++++++- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/shiftedNormL0.jl b/src/shiftedNormL0.jl index f817ac9b..425ec365 100644 --- a/src/shiftedNormL0.jl +++ b/src/shiftedNormL0.jl @@ -67,16 +67,22 @@ function iprox!( λ = ψ.h.lambda for i ∈ eachindex(y) di = d[i] - if d[i] ≤ 0 - y[i] = - 1/eps(R) - continue + gi = g[i] + if di < 0 + y[i] = -Inf + elseif di == 0 + if gi == 0 + y[i] = -ψ.xk[i] - ψ.sj[i] + else + y[i] = sign(gi) * Inf + end else ci = sqrt(2 * λ * di) xps = ψ.xk[i] + ψ.sj[i] - if abs(di * xps - g[i]) ≤ ci + if abs(di * xps - gi) ≤ ci y[i] = -xps else - y[i] = -g[i] / di + y[i] = -gi / di end end end diff --git a/src/shiftedNormL1.jl b/src/shiftedNormL1.jl index e2d23fc3..22ab1b75 100644 --- a/src/shiftedNormL1.jl +++ b/src/shiftedNormL1.jl @@ -67,11 +67,16 @@ function iprox!( @. y = -ψ.xk - ψ.sj for i ∈ eachindex(y) - if d[i] < 0 - y[i] = - 1/eps(R) - continue + di = d[i] + gi = g[i] + if di < 0 + y[i] = -Inf + elseif di == 0 + if abs(gi) > λ + y[i] = sign(gi) * Inf + end else - y[i] = min(max(y[i], -g[i] / d[i] - λ / d[i]), -g[i] / d[i] + λ / d[i]) + y[i] = min(max(y[i], -gi / di - λ / di), -gi / di + λ / di) end end diff --git a/test/partial_prox.jl b/test/partial_prox.jl index bef783d7..bf353a40 100644 --- a/test/partial_prox.jl +++ b/test/partial_prox.jl @@ -1,7 +1,8 @@ # test partial prox feature for operators that implement it for op ∈ (:NormL0, :NormL1, :RootNormLhalf) @testset "shifted $op with box partial prox" begin - h = eval(op)(3.14) + λ = 3.14 + h = eval(op)(λ) n = 5 l = zeros(n) u = ones(n) @@ -58,6 +59,7 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf) # tests iprox without bounds if op == :NormL0 || op == :NormL1 ψ = shifted(h, x) + # test iprox with d > 0 for d ∈ [ones(n), 2 * ones(n)] y = iprox(ψ, q, d) σ = d[1] @@ -68,6 +70,19 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf) end end end + # test iprox with d < 0 + for d ∈ [-ones(n), -2 * ones(n)] + y = iprox(ψ, q, d) + @test all(isinf.(y)) + end + # test iprox with d = 0 + d = zeros(n) + q1 = (λ + 1) * ones(n) + y = iprox(ψ, q1, d) + @test all(isinf.(y)) + q2 = zeros(n) + y = iprox(ψ, q2, d) + @test all(y .== -ψ.xk - ψ.sj) end end end