diff --git a/src/shiftedNormL0.jl b/src/shiftedNormL0.jl index 38be69d2..425ec365 100644 --- a/src/shiftedNormL0.jl +++ b/src/shiftedNormL0.jl @@ -67,13 +67,23 @@ function iprox!( λ = ψ.h.lambda for i ∈ eachindex(y) di = d[i] - @assert di > 0 - ci = sqrt(2 * λ * di) - xps = ψ.xk[i] + ψ.sj[i] - if abs(di * xps - g[i]) ≤ ci - y[i] = -xps - else - y[i] = -g[i] / di + gi = g[i] + if di < 0 + y[i] = -Inf + elseif di == 0 + if gi == 0 + y[i] = -ψ.xk[i] - ψ.sj[i] + else + y[i] = sign(gi) * Inf + end + else + ci = sqrt(2 * λ * di) + xps = ψ.xk[i] + ψ.sj[i] + if abs(di * xps - gi) ≤ ci + y[i] = -xps + else + y[i] = -gi / di + end end end return y diff --git a/src/shiftedNormL1.jl b/src/shiftedNormL1.jl index af89b210..22ab1b75 100644 --- a/src/shiftedNormL1.jl +++ b/src/shiftedNormL1.jl @@ -67,8 +67,17 @@ function iprox!( @. y = -ψ.xk - ψ.sj for i ∈ eachindex(y) - @assert d[i] > 0 - y[i] = min(max(y[i], -g[i] / d[i] - λ / d[i]), -g[i] / d[i] + λ / d[i]) + di = d[i] + gi = g[i] + if di < 0 + y[i] = -Inf + elseif di == 0 + if abs(gi) > λ + y[i] = sign(gi) * Inf + end + else + y[i] = min(max(y[i], -gi / di - λ / di), -gi / di + λ / di) + end end return y diff --git a/test/partial_prox.jl b/test/partial_prox.jl index e7e5b2ff..bf353a40 100644 --- a/test/partial_prox.jl +++ b/test/partial_prox.jl @@ -1,7 +1,8 @@ # test partial prox feature for operators that implement it for op ∈ (:NormL0, :NormL1, :RootNormLhalf) @testset "shifted $op with box partial prox" begin - h = eval(op)(3.14) + λ = 3.14 + h = eval(op)(λ) n = 5 l = zeros(n) u = ones(n) @@ -58,7 +59,7 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf) # tests iprox without bounds if op == :NormL0 || op == :NormL1 ψ = shifted(h, x) - @test_throws AssertionError iprox(ψ, q, zeros(n)) + # test iprox with d > 0 for d ∈ [ones(n), 2 * ones(n)] y = iprox(ψ, q, d) σ = d[1] @@ -69,6 +70,19 @@ for op ∈ (:NormL0, :NormL1, :RootNormLhalf) end end end + # test iprox with d < 0 + for d ∈ [-ones(n), -2 * ones(n)] + y = iprox(ψ, q, d) + @test all(isinf.(y)) + end + # test iprox with d = 0 + d = zeros(n) + q1 = (λ + 1) * ones(n) + y = iprox(ψ, q1, d) + @test all(isinf.(y)) + q2 = zeros(n) + y = iprox(ψ, q2, d) + @test all(y .== -ψ.xk - ψ.sj) end end end