From 0b91ffb6538f8babc9990ec429c89e5c7d7d7666 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 18 Mar 2025 16:38:08 +0100 Subject: [PATCH 1/7] move traits to ProximalCore.jl & implement is_locally_smooth where suitable --- Project.toml | 4 +-- src/ProximalOperators.jl | 29 +++++++++++------ src/calculus/conjugate.jl | 12 +++---- src/calculus/distL2.jl | 5 +-- src/calculus/pointwiseMinimum.jl | 4 +-- src/calculus/postcompose.jl | 11 ++++--- src/calculus/precompose.jl | 11 ++++--- src/calculus/precomposeDiagonal.jl | 11 ++++--- src/calculus/regularize.jl | 3 +- src/calculus/separableSum.jl | 19 +++++------ src/calculus/slicedSeparableSum.jl | 44 +++++++++++++------------- src/calculus/sum.jl | 21 ++++++------ src/calculus/tilt.jl | 4 +-- src/calculus/translate.jl | 11 ++++--- src/functions/elasticNet.jl | 3 +- src/functions/indAffine.jl | 2 +- src/functions/indAffineIterative.jl | 2 +- src/functions/indBallL0.jl | 2 +- src/functions/indBallL1.jl | 4 +-- src/functions/indBallL2.jl | 2 +- src/functions/indBallRank.jl | 4 +-- src/functions/indBinary.jl | 2 +- src/functions/indBox.jl | 2 +- src/functions/indExp.jl | 2 +- src/functions/indFree.jl | 4 +-- src/functions/indGraph.jl | 4 +-- src/functions/indHalfspace.jl | 2 +- src/functions/indHyperslab.jl | 2 +- src/functions/indNonnegative.jl | 2 +- src/functions/indNonpositive.jl | 2 +- src/functions/indPSD.jl | 2 +- src/functions/indPoint.jl | 4 +-- src/functions/indPolyhedral.jl | 2 +- src/functions/indPolyhedralOSQP.jl | 2 +- src/functions/indSOC.jl | 4 +-- src/functions/indSimplex.jl | 2 +- src/functions/indSphereL2.jl | 2 +- src/functions/indStiefel.jl | 2 +- src/functions/indZero.jl | 6 ++-- src/functions/leastSquaresIterative.jl | 2 +- src/functions/logBarrier.jl | 1 + src/functions/logisticLoss.jl | 2 +- src/functions/normL1.jl | 1 + src/functions/normL2.jl | 1 + src/functions/normLinf.jl | 2 ++ src/functions/quadraticIterative.jl | 2 +- src/functions/sumPositive.jl | 2 ++ src/utilities/traits.jl | 32 ------------------- test/Project.toml | 1 + test/runtests.jl | 35 ++++++++++---------- test/test_calls.jl | 2 +- test/test_gradients.jl | 2 +- test/test_huberLoss.jl | 6 ++-- test/test_indAffine.jl | 8 ++--- test/test_indPolyhedral.jl | 4 +-- test/test_leastSquares.jl | 8 ++--- test/test_moreauEnvelope.jl | 12 +++---- test/test_pointwiseMinimum.jl | 4 +-- test/test_precompose.jl | 6 ++-- test/test_quadratic.jl | 6 ++-- test/test_results.jl | 2 +- test/test_sum.jl | 12 +++---- 62 files changed, 202 insertions(+), 207 deletions(-) delete mode 100644 src/utilities/traits.jl diff --git a/Project.toml b/Project.toml index cfb96db8..8de132c4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ProximalOperators" uuid = "a725b495-10eb-56fe-b38b-717eba820537" -version = "0.16.1" +version = "0.17.0" [deps] IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" @@ -15,7 +15,7 @@ TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" IterativeSolvers = "0.8 - 0.9" LinearAlgebra = "1.4" OSQP = "0.3 - 0.8" -ProximalCore = "0.1" +ProximalCore = "0.2" SparseArrays = "1.4" SuiteSparse = "1.4" TSVD = "0.3 - 0.4" diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl index 7b9575a1..bd3fc6e7 100644 --- a/src/ProximalOperators.jl +++ b/src/ProximalOperators.jl @@ -4,15 +4,27 @@ module ProximalOperators using LinearAlgebra import ProximalCore: prox, prox!, gradient, gradient! 
-import ProximalCore: is_convex, is_generalized_quadratic +import ProximalCore: + is_convex, + is_strongly_convex, + is_generalized_quadratic, + is_proximable, + is_separable, + is_singleton_indicator, + is_cone_indicator, + is_affine_indicator, + is_set_indicator, + is_smooth, + is_locally_smooth, + is_support -const RealOrComplex{R <: Real} = Union{R, Complex{R}} -const HermOrSym{T, S} = Union{Hermitian{T, S}, Symmetric{T, S}} -const RealBasedArray{R} = AbstractArray{C, N} where {C <: RealOrComplex{R}, N} -const TupleOfArrays{R} = Tuple{RealBasedArray{R}, Vararg{RealBasedArray{R}}} -const ArrayOrTuple{R} = Union{RealBasedArray{R}, TupleOfArrays{R}} -const TransposeOrAdjoint{M} = Union{Transpose{C,M} where C, Adjoint{C,M} where C} -const Maybe{T} = Union{T, Nothing} +const RealOrComplex{R<:Real} = Union{R,Complex{R}} +const HermOrSym{T,S} = Union{Hermitian{T,S},Symmetric{T,S}} +const RealBasedArray{R} = AbstractArray{C,N} where {C<:RealOrComplex{R},N} +const TupleOfArrays{R} = Tuple{RealBasedArray{R},Vararg{RealBasedArray{R}}} +const ArrayOrTuple{R} = Union{RealBasedArray{R},TupleOfArrays{R}} +const TransposeOrAdjoint{M} = Union{Transpose{C,M} where C,Adjoint{C,M} where C} +const Maybe{T} = Union{T,Nothing} export prox, prox!, gradient, gradient! @@ -23,7 +35,6 @@ include("utilities/linops.jl") include("utilities/symmetricpacked.jl") include("utilities/uniformarrays.jl") include("utilities/normdiff.jl") -include("utilities/traits.jl") # Basic functions diff --git a/src/calculus/conjugate.jl b/src/calculus/conjugate.jl index cfb1138c..23313faa 100644 --- a/src/calculus/conjugate.jl +++ b/src/calculus/conjugate.jl @@ -20,14 +20,14 @@ struct Conjugate{T} end end -is_prox_accurate(::Type{Conjugate{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{Conjugate{T}}) where T = is_proximable(T) is_convex(::Type{Conjugate{T}}) where T = true -is_cone(::Type{Conjugate{T}}) where T = is_cone(T) && is_convex(T) +is_cone_indicator(::Type{Conjugate{T}}) where T = is_cone_indicator(T) && is_convex(T) is_smooth(::Type{Conjugate{T}}) where T = is_strongly_convex(T) is_strongly_convex(::Type{Conjugate{T}}) where T = is_smooth(T) is_generalized_quadratic(::Type{Conjugate{T}}) where T = is_generalized_quadratic(T) -is_set(::Type{Conjugate{T}}) where T = is_convex(T) && is_support(T) -is_positively_homogeneous(::Type{Conjugate{T}}) where T = is_convex(T) && is_set(T) +is_set_indicator(::Type{Conjugate{T}}) where T = is_convex(T) && is_support(T) +is_positively_homogeneous(::Type{Conjugate{T}}) where T = is_convex(T) && is_set_indicator(T) Conjugate(f::T) where T = Conjugate{T}(f) @@ -37,7 +37,7 @@ Conjugate(f::T) where T = Conjugate{T}(f) function prox!(y, g::Conjugate, x, gamma) # Moreau identity v = prox!(y, g.f, x/gamma, 1/gamma) - if is_set(g) + if is_set_indicator(g) v = real(eltype(x))(0) else v = real(dot(x, y)) - gamma * real(dot(y, y)) - v @@ -50,7 +50,7 @@ end function prox_naive(g::Conjugate, x, gamma) y, v = prox_naive(g.f, x/gamma, 1/gamma) - return x - gamma * y, if is_set(g) real(eltype(x))(0) else real(dot(x, y)) - gamma * real(dot(y, y)) - v end + return x - gamma * y, if is_set_indicator(g) real(eltype(x))(0) else real(dot(x, y)) - gamma * real(dot(y, y)) - v end end # TODO: hard-code conjugation rules? E.g. 
precompose/epicompose diff --git a/src/calculus/distL2.jl b/src/calculus/distL2.jl index c80fd79f..f7947fb9 100644 --- a/src/calculus/distL2.jl +++ b/src/calculus/distL2.jl @@ -14,7 +14,7 @@ struct DistL2{R, T} ind::T lambda::R function DistL2{R, T}(ind::T, lambda::R) where {R, T} - if !is_set(ind) + if !is_set_indicator(ind) error("`ind` must be a convex set") end if lambda <= 0 @@ -25,7 +25,8 @@ struct DistL2{R, T} end end -is_prox_accurate(::Type{DistL2{R, T}}) where {R, T} = is_prox_accurate(T) +is_proximable(::Type{DistL2{R, T}}) where {R, T} = is_proximable(T) +is_locally_smooth(::Type{DistL2{R, T}}) where {R, T} = is_proximable(T) is_convex(::Type{DistL2{R, T}}) where {R, T} = is_convex(T) DistL2(ind::T, lambda::R=1) where {R, T} = DistL2{R, T}(ind, lambda) diff --git a/src/calculus/pointwiseMinimum.jl b/src/calculus/pointwiseMinimum.jl index c00dd78f..71ad8663 100644 --- a/src/calculus/pointwiseMinimum.jl +++ b/src/calculus/pointwiseMinimum.jl @@ -17,8 +17,8 @@ PointwiseMinimum(fs...) = PointwiseMinimum{typeof(fs)}(fs) component_types(::Type{PointwiseMinimum{T}}) where T = fieldtypes(T) -@generated is_set(::Type{T}) where T <: PointwiseMinimum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: PointwiseMinimum = return all(is_cone, component_types(T)) ? :(true) : :(false) +@generated is_set_indicator(::Type{T}) where T <: PointwiseMinimum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false) +@generated is_cone_indicator(::Type{T}) where T <: PointwiseMinimum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false) function (g::PointwiseMinimum{T})(x) where T return minimum(f(x) for f in g.fs) diff --git a/src/calculus/postcompose.jl b/src/calculus/postcompose.jl index ec87066f..30cc8891 100644 --- a/src/calculus/postcompose.jl +++ b/src/calculus/postcompose.jl @@ -23,14 +23,15 @@ struct Postcompose{T, R, S} end end -is_prox_accurate(::Type{<:Postcompose{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Postcompose{T}}) where T = is_proximable(T) is_separable(::Type{<:Postcompose{T}}) where T = is_separable(T) is_convex(::Type{<:Postcompose{T}}) where T = is_convex(T) -is_set(::Type{<:Postcompose{T}}) where T = is_set(T) -is_singleton(::Type{<:Postcompose{T}}) where T = is_singleton(T) -is_cone(::Type{<:Postcompose{T}}) where T = is_cone(T) -is_affine(::Type{<:Postcompose{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:Postcompose{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:Postcompose{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:Postcompose{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:Postcompose{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:Postcompose{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Postcompose{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Postcompose{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Postcompose{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/precompose.jl b/src/calculus/precompose.jl index 6a3a2776..a104ceb2 100644 --- a/src/calculus/precompose.jl +++ b/src/calculus/precompose.jl @@ -37,13 +37,14 @@ struct Precompose{T, M, U, V} end end -is_prox_accurate(::Type{<:Precompose{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Precompose{T}}) where T = is_proximable(T) is_convex(::Type{<:Precompose{T}}) where T = is_convex(T) -is_set(::Type{<:Precompose{T}}) where T = is_set(T) 
-is_singleton(::Type{<:Precompose{T}}) where T = is_singleton(T) -is_cone(::Type{<:Precompose{T}}) where T = is_cone(T) -is_affine(::Type{<:Precompose{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:Precompose{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:Precompose{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:Precompose{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:Precompose{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:Precompose{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Precompose{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Precompose{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Precompose{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/precomposeDiagonal.jl b/src/calculus/precomposeDiagonal.jl index f33d2cf7..07874581 100644 --- a/src/calculus/precomposeDiagonal.jl +++ b/src/calculus/precomposeDiagonal.jl @@ -32,13 +32,14 @@ struct PrecomposeDiagonal{T, R, S} end is_separable(::Type{<:PrecomposeDiagonal{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:PrecomposeDiagonal{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:PrecomposeDiagonal{T}}) where T = is_proximable(T) is_convex(::Type{<:PrecomposeDiagonal{T}}) where T = is_convex(T) -is_set(::Type{<:PrecomposeDiagonal{T}}) where T = is_set(T) -is_singleton(::Type{<:PrecomposeDiagonal{T}}) where T = is_singleton(T) -is_cone(::Type{<:PrecomposeDiagonal{T}}) where T = is_cone(T) -is_affine(::Type{<:PrecomposeDiagonal{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:PrecomposeDiagonal{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:PrecomposeDiagonal{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:PrecomposeDiagonal{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:PrecomposeDiagonal{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/regularize.jl b/src/calculus/regularize.jl index 826a9841..3d266db8 100644 --- a/src/calculus/regularize.jl +++ b/src/calculus/regularize.jl @@ -25,9 +25,10 @@ struct Regularize{T, S, A} end is_separable(::Type{<:Regularize{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:Regularize{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Regularize{T}}) where T = is_proximable(T) is_convex(::Type{<:Regularize{T}}) where T = is_convex(T) is_smooth(::Type{<:Regularize{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Regularize{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Regularize{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Regularize}) = true diff --git a/src/calculus/separableSum.jl b/src/calculus/separableSum.jl index 5142cf40..e8999e24 100644 --- a/src/calculus/separableSum.jl +++ b/src/calculus/separableSum.jl @@ -29,15 +29,16 @@ SeparableSum(fs::Vararg) = SeparableSum((fs...,)) component_types(::Type{SeparableSum{T}}) where T = fieldtypes(T) -@generated is_prox_accurate(::Type{T}) where T <: SeparableSum = return all(is_prox_accurate, component_types(T)) ? 
:(true) : :(false) -@generated is_convex(::Type{T}) where T <: SeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false) -@generated is_set(::Type{T}) where T <: SeparableSum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_singleton(::Type{T}) where T <: SeparableSum = return all(is_singleton, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: SeparableSum = return all(is_cone, component_types(T)) ? :(true) : :(false) -@generated is_affine(::Type{T}) where T <: SeparableSum = return all(is_affine, component_types(T)) ? :(true) : :(false) -@generated is_smooth(::Type{T}) where T <: SeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false) -@generated is_generalized_quadratic(::Type{T}) where T <: SeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) -@generated is_strongly_convex(::Type{T}) where T <: SeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false) +@generated is_proximable(::Type{T}) where T <: SeparableSum = return all(is_proximable, component_types(T)) ? true : false +@generated is_convex(::Type{T}) where T <: SeparableSum = return all(is_convex, component_types(T)) ? true : false +@generated is_set_indicator(::Type{T}) where T <: SeparableSum = return all(is_set_indicator, component_types(T)) ? true : false +@generated is_singleton_indicator(::Type{T}) where T <: SeparableSum = return all(is_singleton_indicator, component_types(T)) ? true : false +@generated is_cone_indicator(::Type{T}) where T <: SeparableSum = return all(is_cone_indicator, component_types(T)) ? true : false +@generated is_affine_indicator(::Type{T}) where T <: SeparableSum = return all(is_affine_indicator, component_types(T)) ? true : false +@generated is_smooth(::Type{T}) where T <: SeparableSum = return all(is_smooth, component_types(T)) ? true : false +@generated is_locally_smooth(::Type{T}) where T <: SeparableSum = return all(is_locally_smooth, component_types(T)) ? true : false +@generated is_generalized_quadratic(::Type{T}) where T <: SeparableSum = return all(is_generalized_quadratic, component_types(T)) ? true : false +@generated is_strongly_convex(::Type{T}) where T <: SeparableSum = return all(is_strongly_convex, component_types(T)) ? 
true : false (g::SeparableSum)(xs::Tuple) = sum(f(x) for (f, x) in zip(g.fs, xs)) diff --git a/src/calculus/slicedSeparableSum.jl b/src/calculus/slicedSeparableSum.jl index 7a40fabc..47167d35 100644 --- a/src/calculus/slicedSeparableSum.jl +++ b/src/calculus/slicedSeparableSum.jl @@ -51,46 +51,46 @@ SlicedSeparableSum(f::F, idxs::T) where {F, T <: Tuple} = SlicedSeparableSum(Tuple(f for k in eachindex(idxs)), idxs) # Unroll the loop over the different types of functions to evaluate -@generated function (f::SlicedSeparableSum{A, B, N})(x) where {A, B, N} - ex = :(v = 0.0) - for i = 1:N # For each function type - ex = quote $ex; - for k in eachindex(f.fs[$i]) # For each function of that type - v += f.fs[$i][k](view(x,f.idxs[$i][k]...)) - end +function (f::SlicedSeparableSum)(x) + v = zero(eltype(x)) + for (fs_group, idxs_group) = zip(f.fs, f.idxs) # For each function type + for (fun, idx) in zip(fs_group, idxs_group) # For each function of that type + v += fun(view(x, idx...)) end end - ex = :($ex; return v) + return v end # Unroll the loop over the different types of functions to prox on -@generated function prox!(y, f::SlicedSeparableSum{A, B, N}, x, gamma) where {A, B, N} - ex = :(v = 0.0) - for i = 1:N # For each function type - ex = quote $ex; - for k in eachindex(f.fs[$i]) # For each function of that type - g = prox!(view(y, f.idxs[$i][k]...), f.fs[$i][k], view(x,f.idxs[$i][k]...), gamma) - v += g +function prox!(y, f::SlicedSeparableSum, x, gamma) + v = zero(eltype(x)) + for (fs_group, idxs_group) = zip(f.fs, f.idxs) # For each function type + for (fun, idx) in zip(fs_group, idxs_group) # For each function of that type + g = if idx isa Tuple + prox!(view(y, idx...), fun, view(x, idx...), gamma) + else + prox!(view(y, idx), fun, view(x, idx), gamma) end + v += g end end - ex = :($ex; return v) + return v end component_types(::Type{SlicedSeparableSum{S, T, N}}) where {S, T, N} = Tuple(A.parameters[1] for A in fieldtypes(S)) -@generated is_prox_accurate(::Type{T}) where T <: SlicedSeparableSum = return all(is_prox_accurate, component_types(T)) ? :(true) : :(false) +@generated is_proximable(::Type{T}) where T <: SlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false) @generated is_convex(::Type{T}) where T <: SlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false) -@generated is_set(::Type{T}) where T <: SlicedSeparableSum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_singleton(::Type{T}) where T <: SlicedSeparableSum = return all(is_singleton, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: SlicedSeparableSum = return all(is_cone, component_types(T)) ? :(true) : :(false) -@generated is_affine(::Type{T}) where T <: SlicedSeparableSum = return all(is_affine, component_types(T)) ? :(true) : :(false) +@generated is_set_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false) +@generated is_singleton_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false) +@generated is_cone_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false) +@generated is_affine_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? 
:(true) : :(false) @generated is_smooth(::Type{T}) where T <: SlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false) @generated is_generalized_quadratic(::Type{T}) where T <: SlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) @generated is_strongly_convex(::Type{T}) where T <: SlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false) function prox_naive(f::SlicedSeparableSum, x, gamma) - fy = 0 + fy = zero(eltype(x)) y = similar(x) for t in eachindex(f.fs) for k in eachindex(f.fs[t]) diff --git a/src/calculus/sum.jl b/src/calculus/sum.jl index 534725f9..155f997d 100644 --- a/src/calculus/sum.jl +++ b/src/calculus/sum.jl @@ -17,16 +17,17 @@ Sum(fs::Vararg) = Sum((fs...,)) component_types(::Type{Sum{T}}) where T = fieldtypes(T) -# note: is_prox_accurate false because prox in general doesn't exist? -is_prox_accurate(::Type{<:Sum}) = false -@generated is_convex(::Type{T}) where T <: Sum = return all(is_convex, component_types(T)) ? :(true) : :(false) -@generated is_set(::Type{T}) where T <: Sum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_singleton(::Type{T}) where T <: Sum = return all(is_singleton, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: Sum = return all(is_cone, component_types(T)) ? :(true) : :(false) -@generated is_affine(::Type{T}) where T <: Sum = return all(is_affine, component_types(T)) ? :(true) : :(false) -@generated is_smooth(::Type{T}) where T <: Sum = return all(is_smooth, component_types(T)) ? :(true) : :(false) -@generated is_generalized_quadratic(::Type{T}) where T <: Sum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) -@generated is_strongly_convex(::Type{T}) where T <: Sum = return (all(is_convex, component_types(T)) && any(is_strongly_convex, component_types(T))) ? :(true) : :(false) +# note: is_proximable false because prox in general doesn't exist? +is_proximable(::Type{<:Sum}) = false +@generated is_convex(::Type{T}) where T <: Sum = return all(is_convex, component_types(T)) ? true : false +@generated is_set_indicator(::Type{T}) where T <: Sum = return all(is_set_indicator, component_types(T)) ? true : false +@generated is_singleton_indicator(::Type{T}) where T <: Sum = return all(is_singleton_indicator, component_types(T)) ? true : false +@generated is_cone_indicator(::Type{T}) where T <: Sum = return all(is_cone_indicator, component_types(T)) ? true : false +@generated is_affine_indicator(::Type{T}) where T <: Sum = return all(is_affine_indicator, component_types(T)) ? true : false +@generated is_smooth(::Type{T}) where T <: Sum = return all(is_smooth, component_types(T)) ? true : false +@generated is_locally_smooth(::Type{T}) where T <: Sum = return all(is_locally_smooth, component_types(T)) ? true : false +@generated is_generalized_quadratic(::Type{T}) where T <: Sum = return all(is_generalized_quadratic, component_types(T)) ? true : false +@generated is_strongly_convex(::Type{T}) where T <: Sum = return (all(is_convex, component_types(T)) && any(is_strongly_convex, component_types(T))) ? 
true : false function (sumobj::Sum)(x) sum = real(eltype(x))(0) diff --git a/src/calculus/tilt.jl b/src/calculus/tilt.jl index 5d7f40b2..3cc7b32d 100644 --- a/src/calculus/tilt.jl +++ b/src/calculus/tilt.jl @@ -17,9 +17,9 @@ struct Tilt{T, S, R} end is_separable(::Type{<:Tilt{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:Tilt{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Tilt{T}}) where T = is_proximable(T) is_convex(::Type{<:Tilt{T}}) where T = is_convex(T) -is_singleton(::Type{<:Tilt{T}}) where T = is_singleton(T) +is_singleton_indicator(::Type{<:Tilt{T}}) where T = is_singleton_indicator(T) is_smooth(::Type{<:Tilt{T}}) where T = is_smooth(T) is_generalized_quadratic(::Type{<:Tilt{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Tilt{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/translate.jl b/src/calculus/translate.jl index 902afca2..e00cc71a 100644 --- a/src/calculus/translate.jl +++ b/src/calculus/translate.jl @@ -14,13 +14,14 @@ struct Translate{T, V} end is_separable(::Type{<:Translate{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:Translate{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Translate{T}}) where T = is_proximable(T) is_convex(::Type{<:Translate{T}}) where T = is_convex(T) -is_set(::Type{<:Translate{T}}) where T = is_set(T) -is_singleton(::Type{<:Translate{T}}) where T = is_singleton(T) -is_cone(::Type{<:Translate{T}}) where T = is_cone(T) -is_affine(::Type{<:Translate{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:Translate{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:Translate{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:Translate{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:Translate{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:Translate{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Translate{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Translate{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Translate{T}}) where T = is_strongly_convex(T) diff --git a/src/functions/elasticNet.jl b/src/functions/elasticNet.jl index 744eb584..96abc493 100644 --- a/src/functions/elasticNet.jl +++ b/src/functions/elasticNet.jl @@ -24,8 +24,9 @@ struct ElasticNet{R, S} end is_separable(f::Type{<:ElasticNet}) = true -is_prox_accurate(f::Type{<:ElasticNet}) = true +is_proximable(f::Type{<:ElasticNet}) = true is_convex(f::Type{<:ElasticNet}) = true +is_locally_smooth(f::Type{<:ElasticNet}) = true ElasticNet(mu::R=1, lambda::S=1) where {R, S} = ElasticNet{R, S}(mu, lambda) diff --git a/src/functions/indAffine.jl b/src/functions/indAffine.jl index 9dda046a..fe2c1ac8 100644 --- a/src/functions/indAffine.jl +++ b/src/functions/indAffine.jl @@ -10,7 +10,7 @@ export IndAffine abstract type IndAffine end -is_affine(f::Type{<:IndAffine}) = true +is_affine_indicator(f::Type{<:IndAffine}) = true is_generalized_quadratic(f::Type{<:IndAffine}) = true fun_name(f::IndAffine) = "Indicator of an affine subspace" diff --git a/src/functions/indAffineIterative.jl b/src/functions/indAffineIterative.jl index 6dd8d7e9..d3b1a7c9 100644 --- a/src/functions/indAffineIterative.jl +++ b/src/functions/indAffineIterative.jl @@ -15,7 +15,7 @@ struct IndAffineIterative{M, V} <: IndAffine end end -is_prox_accurate(f::Type{<:IndAffineIterative}) = false +is_proximable(f::Type{<:IndAffineIterative}) = false IndAffineIterative(A::M, b::V) where {M, V} = IndAffineIterative{M, V}(A, b) diff --git 
a/src/functions/indBallL0.jl b/src/functions/indBallL0.jl index b2377bc1..1e28c267 100644 --- a/src/functions/indBallL0.jl +++ b/src/functions/indBallL0.jl @@ -22,7 +22,7 @@ struct IndBallL0{I} end end -is_set(f::Type{<:IndBallL0}) = true +is_set_indicator(f::Type{<:IndBallL0}) = true IndBallL0(r::I) where {I} = IndBallL0{I}(r) diff --git a/src/functions/indBallL1.jl b/src/functions/indBallL1.jl index b646dac1..b37a0608 100644 --- a/src/functions/indBallL1.jl +++ b/src/functions/indBallL1.jl @@ -23,8 +23,8 @@ struct IndBallL1{R} end is_convex(f::Type{<:IndBallL1}) = true -is_set(f::Type{<:IndBallL1}) = true -is_prox_accurate(f::Type{<:IndBallL1}) = false +is_set_indicator(f::Type{<:IndBallL1}) = true +is_proximable(f::Type{<:IndBallL1}) = false IndBallL1(r::R=1.0) where R = IndBallL1{R}(r) diff --git a/src/functions/indBallL2.jl b/src/functions/indBallL2.jl index 03dd0ab4..0e3871eb 100644 --- a/src/functions/indBallL2.jl +++ b/src/functions/indBallL2.jl @@ -23,7 +23,7 @@ struct IndBallL2{R} end is_convex(f::Type{<:IndBallL2}) = true -is_set(f::Type{<:IndBallL2}) = true +is_set_indicator(f::Type{<:IndBallL2}) = true IndBallL2(r::R=1) where R = IndBallL2{R}(r) diff --git a/src/functions/indBallRank.jl b/src/functions/indBallRank.jl index deeaca9c..f799dd73 100644 --- a/src/functions/indBallRank.jl +++ b/src/functions/indBallRank.jl @@ -25,8 +25,8 @@ struct IndBallRank{I} end end -is_set(f::Type{<:IndBallRank}) = true -is_prox_accurate(f::Type{<:IndBallRank}) = false +is_set_indicator(f::Type{<:IndBallRank}) = true +is_proximable(f::Type{<:IndBallRank}) = false IndBallRank(r::I=1) where I = IndBallRank{I}(r) diff --git a/src/functions/indBinary.jl b/src/functions/indBinary.jl index 282d7630..403d9b7d 100644 --- a/src/functions/indBinary.jl +++ b/src/functions/indBinary.jl @@ -16,7 +16,7 @@ struct IndBinary{T, S} high::S end -is_set(f::Type{<:IndBinary}) = true +is_set_indicator(f::Type{<:IndBinary}) = true IndBinary() = IndBinary(0, 1) diff --git a/src/functions/indBox.jl b/src/functions/indBox.jl index 685fead2..e757fef2 100644 --- a/src/functions/indBox.jl +++ b/src/functions/indBox.jl @@ -28,7 +28,7 @@ end is_separable(f::Type{<:IndBox}) = true is_convex(f::Type{<:IndBox}) = true -is_set(f::Type{<:IndBox}) = true +is_set_indicator(f::Type{<:IndBox}) = true compatible_bounds(::Real, ::Real) = true compatible_bounds(::Real, ::AbstractArray) = true diff --git a/src/functions/indExp.jl b/src/functions/indExp.jl index 8dffcbfa..09bb45ef 100644 --- a/src/functions/indExp.jl +++ b/src/functions/indExp.jl @@ -14,7 +14,7 @@ C = \\mathrm{cl} \\{ (r,s,t) : s > 0, s⋅e^{r/s} \\leq t \\} \\subset \\mathbb{ struct IndExpPrimal end is_convex(f::Type{<:IndExpPrimal}) = true -is_cone(f::Type{<:IndExpPrimal}) = true +is_cone_indicator(f::Type{<:IndExpPrimal}) = true """ IndExpDual() diff --git a/src/functions/indFree.jl b/src/functions/indFree.jl index 1afb0ecc..b5ef5f6c 100644 --- a/src/functions/indFree.jl +++ b/src/functions/indFree.jl @@ -12,8 +12,8 @@ struct IndFree end is_separable(f::Type{<:IndFree}) = true is_convex(f::Type{<:IndFree}) = true -is_affine(f::Type{<:IndFree}) = true -is_cone(f::Type{<:IndFree}) = true +is_affine_indicator(f::Type{<:IndFree}) = true +is_cone_indicator(f::Type{<:IndFree}) = true is_smooth(f::Type{<:IndFree}) = true is_generalized_quadratic(f::Type{<:IndFree}) = true diff --git a/src/functions/indGraph.jl b/src/functions/indGraph.jl index c86c99e0..aafdafeb 100644 --- a/src/functions/indGraph.jl +++ b/src/functions/indGraph.jl @@ -29,8 +29,8 @@ function 
IndGraph(A::AbstractMatrix) end is_convex(f::Type{<:IndGraph}) = true -is_set(f::Type{<:IndGraph}) = true -is_cone(f::Type{<:IndGraph}) = true +is_set_indicator(f::Type{<:IndGraph}) = true +is_cone_indicator(f::Type{<:IndGraph}) = true IndGraph(a::AbstractVector) = IndGraph(a') diff --git a/src/functions/indHalfspace.jl b/src/functions/indHalfspace.jl index 8c3237b5..d44a45d6 100644 --- a/src/functions/indHalfspace.jl +++ b/src/functions/indHalfspace.jl @@ -26,7 +26,7 @@ end IndHalfspace(a::T, b::R) where {R, T} = IndHalfspace{R, T}(a, b) is_convex(f::Type{<:IndHalfspace}) = true -is_set(f::Type{<:IndHalfspace}) = true +is_set_indicator(f::Type{<:IndHalfspace}) = true function (f::IndHalfspace)(x) R = real(eltype(x)) diff --git a/src/functions/indHyperslab.jl b/src/functions/indHyperslab.jl index e32274ed..9602324d 100644 --- a/src/functions/indHyperslab.jl +++ b/src/functions/indHyperslab.jl @@ -27,7 +27,7 @@ end IndHyperslab(low::R, a::T, upp::R) where {R, T} = IndHyperslab{R, T}(low, a, upp) is_convex(f::Type{<:IndHyperslab}) = true -is_set(f::Type{<:IndHyperslab}) = true +is_set_indicator(f::Type{<:IndHyperslab}) = true function (f::IndHyperslab)(x) R = real(eltype(x)) diff --git a/src/functions/indNonnegative.jl b/src/functions/indNonnegative.jl index ba5b3bf2..fc08a6c4 100644 --- a/src/functions/indNonnegative.jl +++ b/src/functions/indNonnegative.jl @@ -14,7 +14,7 @@ struct IndNonnegative end is_separable(f::Type{<:IndNonnegative}) = true is_convex(f::Type{<:IndNonnegative}) = true -is_cone(f::Type{<:IndNonnegative}) = true +is_cone_indicator(f::Type{<:IndNonnegative}) = true function (::IndNonnegative)(x) R = eltype(x) diff --git a/src/functions/indNonpositive.jl b/src/functions/indNonpositive.jl index ba481e96..7dac78f1 100644 --- a/src/functions/indNonpositive.jl +++ b/src/functions/indNonpositive.jl @@ -14,7 +14,7 @@ struct IndNonpositive end is_separable(f::Type{<:IndNonpositive}) = true is_convex(f::Type{<:IndNonpositive}) = true -is_cone(f::Type{<:IndNonpositive}) = true +is_cone_indicator(f::Type{<:IndNonpositive}) = true function (::IndNonpositive)(x) R = eltype(x) diff --git a/src/functions/indPSD.jl b/src/functions/indPSD.jl index 36e72a66..9ec57946 100644 --- a/src/functions/indPSD.jl +++ b/src/functions/indPSD.jl @@ -46,7 +46,7 @@ function (::IndPSD)(X::Union{Symmetric, Hermitian}) end is_convex(f::Type{<:IndPSD}) = true -is_cone(f::Type{<:IndPSD}) = true +is_cone_indicator(f::Type{<:IndPSD}) = true function prox!(Y::Union{Symmetric, Hermitian}, ::IndPSD, X::Union{Symmetric, Hermitian}, gamma) R = real(eltype(X)) diff --git a/src/functions/indPoint.jl b/src/functions/indPoint.jl index afea3540..fe7f40f9 100644 --- a/src/functions/indPoint.jl +++ b/src/functions/indPoint.jl @@ -20,8 +20,8 @@ end is_separable(f::Type{<:IndPoint}) = true is_convex(f::Type{<:IndPoint}) = true -is_singleton(f::Type{<:IndPoint}) = true -is_affine(f::Type{<:IndPoint}) = true +is_singleton_indicator(f::Type{<:IndPoint}) = true +is_affine_indicator(f::Type{<:IndPoint}) = true IndPoint(p::T=0) where T = IndPoint{T}(p) diff --git a/src/functions/indPolyhedral.jl b/src/functions/indPolyhedral.jl index 77c894a5..8405c2c3 100644 --- a/src/functions/indPolyhedral.jl +++ b/src/functions/indPolyhedral.jl @@ -3,7 +3,7 @@ export IndPolyhedral abstract type IndPolyhedral end is_convex(::Type{<:IndPolyhedral}) = true -is_set(::Type{<:IndPolyhedral}) = true +is_set_indicator(::Type{<:IndPolyhedral}) = true """ IndPolyhedral([l,] A, [u, xmin, xmax]) diff --git a/src/functions/indPolyhedralOSQP.jl 
b/src/functions/indPolyhedralOSQP.jl index 3d2b4902..6fb67d7f 100644 --- a/src/functions/indPolyhedralOSQP.jl +++ b/src/functions/indPolyhedralOSQP.jl @@ -24,7 +24,7 @@ end # properties -is_prox_accurate(::Type{<:IndPolyhedralOSQP}) = false +is_proximable(::Type{<:IndPolyhedralOSQP}) = false # constructors diff --git a/src/functions/indSOC.jl b/src/functions/indSOC.jl index 2f3e6da0..55547f06 100644 --- a/src/functions/indSOC.jl +++ b/src/functions/indSOC.jl @@ -22,7 +22,7 @@ function (::IndSOC)(x) end is_convex(f::Type{<:IndSOC}) = true -is_cone(f::Type{<:IndSOC}) = true +is_cone_indicator(f::Type{<:IndSOC}) = true function prox!(y, ::IndSOC, x, gamma) T = eltype(x) @@ -84,7 +84,7 @@ function (::IndRotatedSOC)(x) end is_convex(f::IndRotatedSOC) = true -is_set(f::IndRotatedSOC) = true +is_set_indicator(f::IndRotatedSOC) = true function prox!(y, ::IndRotatedSOC, x, gamma) T = eltype(x) diff --git a/src/functions/indSimplex.jl b/src/functions/indSimplex.jl index 451423d4..65f7bce5 100644 --- a/src/functions/indSimplex.jl +++ b/src/functions/indSimplex.jl @@ -24,7 +24,7 @@ struct IndSimplex{R} end is_convex(f::Type{<:IndSimplex}) = true -is_set(f::Type{<:IndSimplex}) = true +is_set_indicator(f::Type{<:IndSimplex}) = true IndSimplex(a::R=1) where R = IndSimplex{R}(a) diff --git a/src/functions/indSphereL2.jl b/src/functions/indSphereL2.jl index ce568717..8a31f87b 100644 --- a/src/functions/indSphereL2.jl +++ b/src/functions/indSphereL2.jl @@ -22,7 +22,7 @@ struct IndSphereL2{R} end end -is_set(f::Type{<:IndSphereL2}) = true +is_set_indicator(f::Type{<:IndSphereL2}) = true IndSphereL2(r::R=1) where R = IndSphereL2{R}(r) diff --git a/src/functions/indStiefel.jl b/src/functions/indStiefel.jl index 433cc3af..3da57258 100644 --- a/src/functions/indStiefel.jl +++ b/src/functions/indStiefel.jl @@ -14,7 +14,7 @@ are inferred from the matrix provided as input. 
""" struct IndStiefel end -is_set(f::Type{<:IndStiefel}) = true +is_set_indicator(f::Type{<:IndStiefel}) = true function (::IndStiefel)(X) R = real(eltype(X)) diff --git a/src/functions/indZero.jl b/src/functions/indZero.jl index 8876bb58..c95efdba 100644 --- a/src/functions/indZero.jl +++ b/src/functions/indZero.jl @@ -11,9 +11,9 @@ struct IndZero end is_separable(f::Type{<:IndZero}) = true is_convex(f::Type{<:IndZero}) = true -is_singleton(f::Type{<:IndZero}) = true -is_cone(f::Type{<:IndZero}) = true -is_affine(f::Type{<:IndZero}) = true +is_singleton_indicator(f::Type{<:IndZero}) = true +is_cone_indicator(f::Type{<:IndZero}) = true +is_affine_indicator(f::Type{<:IndZero}) = true function (::IndZero)(x) C = eltype(x) diff --git a/src/functions/leastSquaresIterative.jl b/src/functions/leastSquaresIterative.jl index 29739eae..a9a848a1 100644 --- a/src/functions/leastSquaresIterative.jl +++ b/src/functions/leastSquaresIterative.jl @@ -16,7 +16,7 @@ struct LeastSquaresIterative{N, R, RC, M, V, O, IsConvex} <: LeastSquares q::Array{RC, N} # n (by-p) end -is_prox_accurate(f::Type{<:LeastSquaresIterative}) = false +is_proximable(f::Type{<:LeastSquaresIterative}) = false is_convex(::Type{LeastSquaresIterative{N, R, RC, M, V, O, IsConvex}}) where {N, R, RC, M, V, O, IsConvex} = IsConvex function LeastSquaresIterative(A::M, b, lambda) where M diff --git a/src/functions/logBarrier.jl b/src/functions/logBarrier.jl index d0826649..66001e67 100644 --- a/src/functions/logBarrier.jl +++ b/src/functions/logBarrier.jl @@ -26,6 +26,7 @@ end is_separable(f::Type{<:LogBarrier}) = true is_convex(f::Type{<:LogBarrier}) = true +is_locally_smooth(f::Type{<:LogBarrier}) = true LogBarrier(a::R=1, b::S=0, mu::T=1) where {R, S, T} = LogBarrier{R, S, T}(a, b, mu) diff --git a/src/functions/logisticLoss.jl b/src/functions/logisticLoss.jl index 0681c427..0f60f91a 100644 --- a/src/functions/logisticLoss.jl +++ b/src/functions/logisticLoss.jl @@ -27,7 +27,7 @@ LogisticLoss(y::T, mu::R=1) where {R, T} = LogisticLoss{T, R}(y, mu) is_separable(f::Type{<:LogisticLoss}) = true is_convex(f::Type{<:LogisticLoss}) = true is_smooth(f::Type{<:LogisticLoss}) = true -is_prox_accurate(f::Type{<:LogisticLoss}) = false +is_proximable(f::Type{<:LogisticLoss}) = false # f(x) = mu log(1 + exp(-y x)) diff --git a/src/functions/normL1.jl b/src/functions/normL1.jl index c057f0f6..f77abbba 100644 --- a/src/functions/normL1.jl +++ b/src/functions/normL1.jl @@ -31,6 +31,7 @@ end is_separable(f::Type{<:NormL1}) = true is_convex(f::Type{<:NormL1}) = true is_positively_homogeneous(f::Type{<:NormL1}) = true +is_locally_smooth(f::Type{<:NormL1}) = true NormL1(lambda::R=1) where R = NormL1{R}(lambda) diff --git a/src/functions/normL2.jl b/src/functions/normL2.jl index 431e28a4..f4c01eb0 100644 --- a/src/functions/normL2.jl +++ b/src/functions/normL2.jl @@ -23,6 +23,7 @@ end is_convex(f::Type{<:NormL2}) = true is_positively_homogeneous(f::Type{<:NormL2}) = true +is_locally_smooth(f::Type{<:NormL2}) = true NormL2(lambda::R=1) where R = NormL2{R}(lambda) diff --git a/src/functions/normLinf.jl b/src/functions/normLinf.jl index d9b88b4f..128b1d34 100644 --- a/src/functions/normLinf.jl +++ b/src/functions/normLinf.jl @@ -13,6 +13,8 @@ for a nonnegative parameter `λ`. 
""" NormLinf(lambda::T=1) where T = Conjugate(IndBallL1(lambda)) +is_locally_smooth(f::Type{<:Conjugate{<:IndBallL1}}) = true + (f::Conjugate{<:IndBallL1})(x) = (f.f.r) * norm(x, Inf) function gradient!(y, f::Conjugate{<:IndBallL1}, x) diff --git a/src/functions/quadraticIterative.jl b/src/functions/quadraticIterative.jl index 6753d64c..15c5c7b0 100644 --- a/src/functions/quadraticIterative.jl +++ b/src/functions/quadraticIterative.jl @@ -9,7 +9,7 @@ struct QuadraticIterative{M, V} <: Quadratic temp::V end -is_prox_accurate(f::Type{<:QuadraticIterative}) = false +is_proximable(f::Type{<:QuadraticIterative}) = false function QuadraticIterative(Q::M, q::V) where {M, V} if size(Q, 1) != size(Q, 2) || length(q) != size(Q, 2) diff --git a/src/functions/sumPositive.jl b/src/functions/sumPositive.jl index ac0f1d62..c4bcdd2c 100644 --- a/src/functions/sumPositive.jl +++ b/src/functions/sumPositive.jl @@ -14,6 +14,8 @@ struct SumPositive end is_separable(f::Type{<:SumPositive}) = true is_convex(f::Type{<:SumPositive}) = true +is_positively_homogeneous(f::Type{<:SumPositive}) = true +is_locally_smooth(f::Type{<:SumPositive}) = true function (::SumPositive)(x) return sum(xi -> max(xi, eltype(x)(0)), x) diff --git a/src/utilities/traits.jl b/src/utilities/traits.jl deleted file mode 100644 index b5d46f5b..00000000 --- a/src/utilities/traits.jl +++ /dev/null @@ -1,32 +0,0 @@ -is_prox_accurate(::Type) = true -is_prox_accurate(::T) where T = is_prox_accurate(T) - -is_separable(::Type) = false -is_separable(::T) where T = is_separable(T) - -is_singleton(::Type) = false -is_singleton(::T) where T = is_singleton(T) - -is_cone(::Type) = false -is_cone(::T) where T = is_cone(T) - -is_affine(T::Type) = is_singleton(T) -is_affine(::T) where T = is_affine(T) - -is_set(T::Type) = is_cone(T) || is_affine(T) -is_set(::T) where T = is_set(T) - -is_positively_homogeneous(T::Type) = is_cone(T) -is_positively_homogeneous(::T) where T = is_positively_homogeneous(T) - -is_support(T::Type) = is_convex(T) && is_positively_homogeneous(T) -is_support(::T) where T = is_support(T) - -is_smooth(::Type) = false -is_smooth(::T) where T = is_smooth(T) - -is_quadratic(T::Type) = is_generalized_quadratic(T) && is_smooth(T) -is_quadratic(::T) where T = is_quadratic(T) - -is_strongly_convex(::Type) = false -is_strongly_convex(::T) where T = is_strongly_convex(T) diff --git a/test/Project.toml b/test/Project.toml index 1b3a6c17..66be2241 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index f91f581d..9abe42bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,16 +1,15 @@ using Test using ProximalOperators - -using ProximalOperators: - ArrayOrTuple, - is_prox_accurate, +using ProximalOperators: ArrayOrTuple +using ProximalCore: ProximalCore, + is_proximable, is_separable, is_convex, - is_singleton, - is_cone, - is_affine, - is_set, + is_singleton_indicator, + is_cone_indicator, + is_affine_indicator, + is_set_indicator, is_smooth, is_quadratic, is_generalized_quadratic, @@ -49,19 +48,19 @@ function prox_test(f, x::ArrayOrTuple{R}, gamma=1) where R <: Real @test typeof(fy_naive) == R - rtol = if ProximalOperators.is_prox_accurate(f) sqrt(eps(R)) else 
1e-4 end + rtol = if is_proximable(f) sqrt(eps(R)) else 1e-4 end - if ProximalOperators.is_convex(f) + if is_convex(f) @test all(isapprox.(y_prealloc, y, rtol=rtol, atol=100*eps(R))) @test all(isapprox.(y_naive, y, rtol=rtol, atol=100*eps(R))) - if ProximalOperators.is_set(f) + if is_set_indicator(f) @test fy_prealloc == 0 end @test isapprox(fy_prealloc, fy, rtol=rtol, atol=100*eps(R)) @test isapprox(fy_naive, fy, rtol=rtol, atol=100*eps(R)) end - if !ProximalOperators.is_set(f) || ProximalOperators.is_prox_accurate(f) + if !is_set_indicator(f) || ProximalCore.is_proximable(f) f_at_y = call_test(f, y) if f_at_y !== nothing @test isapprox(f_at_y, fy, rtol=rtol, atol=100*eps(R)) @@ -88,10 +87,10 @@ function predicates_test(f) is_generalized_quadratic, is_quadratic, is_smooth, - is_singleton, - is_cone, - is_affine, - is_set, + is_singleton_indicator, + is_cone_indicator, + is_affine_indicator, + is_set_indicator, is_positively_homogeneous, is_support, ] @@ -104,9 +103,9 @@ function predicates_test(f) # quadratic => generalized_quadratic && smooth @test !is_quadratic(f) || (is_generalized_quadratic(f) && is_smooth(f)) # (singleton || cone || affine) => set - @test !(is_singleton(f) || is_cone(f) || is_affine(f)) || is_set(f) + @test !(is_singleton_indicator(f) || is_cone_indicator(f) || is_affine_indicator(f)) || is_set_indicator(f) # cone => positively homogeneous - @test !is_cone(f) || is_positively_homogeneous(f) + @test !is_cone_indicator(f) || is_positively_homogeneous(f) # (convex && positively homogeneous) <=> (convex && support) @test (is_convex(f) && is_positively_homogeneous(f)) == (is_convex(f) && is_support(f)) # strongly_convex => convex diff --git a/test/test_calls.jl b/test/test_calls.jl index ee24121e..a022b406 100644 --- a/test/test_calls.jl +++ b/test/test_calls.jl @@ -598,7 +598,7 @@ test_cases_spec = [ y, fy = prox_test(f, x, gam) ##### compute prox with multiple random gammas - if ProximalOperators.is_separable(f) + if ProximalCore.is_separable(f) gam = real(T)(0.5) .+ 2 .* rand(real(T), size(x)) y, fy = prox_test(f, x, gam) end diff --git a/test/test_gradients.jl b/test/test_gradients.jl index aa70041b..238859cf 100644 --- a/test/test_gradients.jl +++ b/test/test_gradients.jl @@ -158,7 +158,7 @@ for i in eachindex(stuff) ∇f, fx = gradient_test(f, x) for k = 1:10 # Test conditions in different directions - if ProximalOperators.is_convex(f) + if ProximalCore.is_convex(f) # Test ∇f is subgradient if typeof(f) <: CrossEntropy d = x.*(rand(Float64, size(x)).-1)./2 # assures 0 <= x+d <= 1 diff --git a/test/test_huberLoss.jl b/test/test_huberLoss.jl index b97c57fb..7e51f6d0 100644 --- a/test/test_huberLoss.jl +++ b/test/test_huberLoss.jl @@ -8,9 +8,9 @@ f = HuberLoss(1.5, 0.7) predicates_test(f) -@test ProximalOperators.is_smooth(f) == true -@test ProximalOperators.is_quadratic(f) == false -@test ProximalOperators.is_set(f) == false +@test ProximalCore.is_smooth(f) == true +@test ProximalCore.is_quadratic(f) == false +@test ProximalCore.is_set_indicator(f) == false x = randn(10) x = 1.6*x/norm(x) diff --git a/test/test_indAffine.jl b/test/test_indAffine.jl index 5111802c..46043b82 100644 --- a/test/test_indAffine.jl +++ b/test/test_indAffine.jl @@ -16,10 +16,10 @@ x = randn(n) predicates_test(f) -@test ProximalOperators.is_smooth(f) == false -@test ProximalOperators.is_quadratic(f) == false -@test ProximalOperators.is_generalized_quadratic(f) == true -@test ProximalOperators.is_set(f) == true +@test ProximalCore.is_smooth(f) == false +@test ProximalCore.is_quadratic(f) == false 
+@test ProximalCore.is_generalized_quadratic(f) == true +@test ProximalCore.is_set_indicator(f) == true call_test(f, x) y, fy = prox_test(f, x) diff --git a/test/test_indPolyhedral.jl b/test/test_indPolyhedral.jl index 97b4efb0..4eef6679 100644 --- a/test/test_indPolyhedral.jl +++ b/test/test_indPolyhedral.jl @@ -30,8 +30,8 @@ p = similar(x) () -> IndPolyhedral(l, A, u, xmin, xmax), ] f = constr() - @test ProximalOperators.is_convex(f) == true - @test ProximalOperators.is_set(f) == true + @test ProximalCore.is_convex(f) == true + @test ProximalCore.is_set_indicator(f) == true fx = call_test(f, x) p, fp = prox_test(f, x) end diff --git a/test/test_leastSquares.jl b/test/test_leastSquares.jl index d3ced458..369010ae 100644 --- a/test/test_leastSquares.jl +++ b/test/test_leastSquares.jl @@ -33,10 +33,10 @@ x = randn(T, shape_x...) f = LeastSquares(A, b, iterative=(mode == :iterative)) predicates_test(f) -@test ProximalOperators.is_smooth(f) == true -@test ProximalOperators.is_quadratic(f) == true -@test ProximalOperators.is_generalized_quadratic(f) == true -@test ProximalOperators.is_set(f) == false +@test ProximalCore.is_smooth(f) == true +@test ProximalCore.is_quadratic(f) == true +@test ProximalCore.is_generalized_quadratic(f) == true +@test ProximalCore.is_set_indicator(f) == false grad_fx, fx = gradient_test(f, x) lsres = A*x - b diff --git a/test/test_moreauEnvelope.jl b/test/test_moreauEnvelope.jl index 882c0f14..cabbbcc2 100644 --- a/test/test_moreauEnvelope.jl +++ b/test/test_moreauEnvelope.jl @@ -14,9 +14,9 @@ using LinearAlgebra predicates_test(g) - @test ProximalOperators.is_smooth(g) == true - @test ProximalOperators.is_quadratic(g) == false - @test ProximalOperators.is_set(g) == false + @test ProximalCore.is_smooth(g) == true + @test ProximalCore.is_quadratic(g) == false + @test ProximalCore.is_set_indicator(g) == false x = R[1.0, 2.0, 3.0, 4.0, 5.0] @@ -40,9 +40,9 @@ end predicates_test(g) - @test ProximalOperators.is_smooth(g) == true - @test ProximalOperators.is_quadratic(g) == false - @test ProximalOperators.is_set(g) == false + @test ProximalCore.is_smooth(g) == true + @test ProximalCore.is_quadratic(g) == false + @test ProximalCore.is_set_indicator(g) == false x = R[1.0, 2.0, 3.0, 4.0, 5.0] diff --git a/test/test_pointwiseMinimum.jl b/test/test_pointwiseMinimum.jl index 60c964fe..a7166bc0 100644 --- a/test/test_pointwiseMinimum.jl +++ b/test/test_pointwiseMinimum.jl @@ -9,8 +9,8 @@ f = PointwiseMinimum(IndPoint(T[-1.0]), IndPoint(T[1.0])) x = T[0.1] predicates_test(f) -@test ProximalOperators.is_set(f) == true -@test ProximalOperators.is_cone(f) == false +@test ProximalCore.is_set_indicator(f) == true +@test ProximalCore.is_cone_indicator(f) == false y, fy = prox_test(f, x) @test all(y .== T[1.0]) diff --git a/test/test_precompose.jl b/test/test_precompose.jl index d57bdfb7..6dd5bf94 100644 --- a/test/test_precompose.jl +++ b/test/test_precompose.jl @@ -17,9 +17,9 @@ g = Precompose(f, Q, 1.0) predicates_test(g) -@test ProximalOperators.is_smooth(g) == false -@test ProximalOperators.is_quadratic(g) == false -@test ProximalOperators.is_set(g) == true +@test ProximalCore.is_smooth(g) == false +@test ProximalCore.is_quadratic(g) == false +@test ProximalCore.is_set_indicator(g) == true x = randn(10) diff --git a/test/test_quadratic.jl b/test/test_quadratic.jl index 6112e643..247b3e62 100644 --- a/test/test_quadratic.jl +++ b/test/test_quadratic.jl @@ -17,9 +17,9 @@ f = Quadratic(Q, q) predicates_test(f) -@test ProximalOperators.is_smooth(f) == true -@test 
 ProximalOperators.is_quadratic(f) == true
-@test ProximalOperators.is_set(f) == false
+@test ProximalCore.is_smooth(f) == true
+@test ProximalCore.is_quadratic(f) == true
+@test ProximalCore.is_set_indicator(f) == false
 x = randn(n)
diff --git a/test/test_results.jl b/test/test_results.jl
index d10125a2..7d95342a 100644
--- a/test/test_results.jl
+++ b/test/test_results.jl
@@ -340,7 +340,7 @@ stuff = [
         y, fy = prox_test(f, x, gamma)
         @test y ≈ ref_y
-        if ProximalOperators.is_prox_accurate(f)
+        if ProximalCore.is_proximable(f)
             @test fy ≈ ref_fy
         end
diff --git a/test/test_sum.jl b/test/test_sum.jl
index 24364c24..c6160580 100644
--- a/test/test_sum.jl
+++ b/test/test_sum.jl
@@ -10,9 +10,9 @@ f = Sum(f1, f2)
 predicates_test(f)
-@test ProximalOperators.is_quadratic(f) == true
-@test ProximalOperators.is_strongly_convex(f) == true
-@test ProximalOperators.is_set(f) == false
+@test ProximalCore.is_quadratic(f) == true
+@test ProximalCore.is_strongly_convex(f) == true
+@test ProximalCore.is_set_indicator(f) == false
 xtest = randn(10)
@@ -33,9 +33,9 @@ g = Sum(g1, g2)
 predicates_test(g)
-@test ProximalOperators.is_smooth(g) == false
-@test ProximalOperators.is_strongly_convex(g) == true
-@test ProximalOperators.is_set(g) == false
+@test ProximalCore.is_smooth(g) == false
+@test ProximalCore.is_strongly_convex(g) == true
+@test ProximalCore.is_set_indicator(g) == false
 xtest = randn(10)

From 8894eff09c3ea3bb950d457200a4604168c8af17 Mon Sep 17 00:00:00 2001
From: Tamas Hakkel
Date: Mon, 14 Apr 2025 19:22:19 +0200
Subject: [PATCH 2/7] Add PrecomposedSlicedSeparableSum implementation and corresponding tests

---
 src/ProximalOperators.jl | 1 +
 src/calculus/precomposedSlicedSeparableSum.jl | 150 ++++++++++++++++++
 test/runtests.jl | 1 +
 test/test_precomposedSlicedSeparableSum.jl | 74 +++++++++
 4 files changed, 226 insertions(+)
 create mode 100644 src/calculus/precomposedSlicedSeparableSum.jl
 create mode 100644 test/test_precomposedSlicedSeparableSum.jl

diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl
index bd3fc6e7..af5bf604 100644
--- a/src/ProximalOperators.jl
+++ b/src/ProximalOperators.jl
@@ -91,6 +91,7 @@ include("calculus/precomposeDiagonal.jl")
 include("calculus/regularize.jl")
 include("calculus/separableSum.jl")
 include("calculus/slicedSeparableSum.jl")
+include("calculus/precomposedSlicedSeparableSum.jl")
 include("calculus/sqrDistL2.jl")
 include("calculus/tilt.jl")
 include("calculus/translate.jl")
diff --git a/src/calculus/precomposedSlicedSeparableSum.jl b/src/calculus/precomposedSlicedSeparableSum.jl
new file mode 100644
index 00000000..26a12776
--- /dev/null
+++ b/src/calculus/precomposedSlicedSeparableSum.jl
@@ -0,0 +1,150 @@
+# Separable sum of functions precomposed with linear operators, using slices of arrays as variables
+
+export PrecomposedSlicedSeparableSum
+
+"""
+    PrecomposedSlicedSeparableSum((f_1, ..., f_k), (J_1, ..., J_k), (L_1, ..., L_k), (μ_1, ..., μ_k))
+
+Return the function
+```math
+g(x) = \\sum_{i=1}^k f_i(L_i * x_{J_i}).
+```
+
+    PrecomposedSlicedSeparableSum(f, (J_1, ..., J_k), (L_1, ..., L_k), (μ_1, ..., μ_k))
+
+Analogous to the previous one, but applies the same function `f` to all slices
+of the variable `x`:
+```math
+g(x) = \\sum_{i=1}^k f(L_i * x_{J_i}).
+``` +""" +struct PrecomposedSlicedSeparableSum{S <: Tuple, T <: AbstractArray, U <: AbstractArray, V <: AbstractArray, N} + fs::S # Tuple, where each element is a Vector with elements of the same type; the functions to prox on + # Example: S = Tuple{Array{ProximalOperators.NormL1{Float64},1}, Array{ProximalOperators.NormL2{Float64},1}} + idxs::T # Vector, where each element is a Vector containing the indices to prox on + # Example: T = Array{Array{Tuple{Colon,UnitRange{Int64}},1},1} + ops::U # Vector of operations (matrices or AbstractOperators) to apply to the function + # Example: U = Array{Array{Matrix{Float64},1},1} + μs::V # Vector of mu values for each function +end + +function PrecomposedSlicedSeparableSum(fs::Tuple, idxs::Tuple, ops::Tuple, μs::Tuple) + @assert length(fs) == length(idxs) + @assert length(fs) == length(ops) + ftypes = DataType[] + fsarr = Array{Any,1}[] + indarr = Array{eltype(idxs),1}[] + opsarr = Array{Any,1}[] + μsarr = Array{Any,1}[] + for (i,f) in enumerate(fs) + t = typeof(f) + fi = findfirst(isequal(t), ftypes) + if fi === nothing + push!(ftypes, t) + push!(fsarr, Any[f]) + push!(indarr, eltype(idxs)[idxs[i]]) + push!(opsarr, Any[ops[i]]) + push!(μsarr, Any[μs[i]]) + else + push!(fsarr[fi], f) + push!(indarr[fi], idxs[i]) + push!(opsarr[fi], ops[i]) + push!(μsarr[fi], μs[i]) + end + end + fsnew = ((Array{typeof(fs[1]),1}(fs) for fs in fsarr)...,) + @assert typeof(fsnew) == Tuple{(Array{ft,1} for ft in ftypes)...} + PrecomposedSlicedSeparableSum{typeof(fsnew),typeof(indarr),typeof(opsarr),typeof(μsarr),length(fsnew)}(fsnew, indarr, opsarr, μsarr) +end + +# Constructor for the case where the same function is applied to all slices +PrecomposedSlicedSeparableSum(f::F, idxs::T, ops::U, μs::V) where {F, T <: Tuple, U <: Tuple, V <: Tuple} = + PrecomposedSlicedSeparableSum(Tuple(f for k in eachindex(idxs)), idxs, ops, μs) + +# Unroll the loop over the different types of functions to evaluate +function (f::PrecomposedSlicedSeparableSum)(x::Tuple) + v = zero(eltype(x[1])) + for (fs_group, idxs_group, ops_group) = zip(f.fs, f.idxs, f.ops) # For each function type + for (fun, idx_group, hcat_op) in zip(fs_group, idxs_group, ops_group) # For each function of that type + for (var_index, (x_var, idx)) in enumerate(zip(x, idx_group)) + if idx isa Tuple + v += fun(hcat_op[var_index] * view(x_var, idx...)) + elseif idx isa Colon + v += fun(hcat_op[var_index] * x_var) + elseif idx isa Nothing + # do nothing + else + v += fun(hcat_op[var_index] * view(x_var, idx)) + end + end + end + end + return v +end + +function slice_var(x, idx) + if idx isa Tuple + return view(x, idx...) 
+    elseif idx isa Colon
+        return x
+    elseif idx isa Nothing
+        return similar(x)
+    else
+        return view(x, idx)
+    end
+end
+
+# Loop over the different types of functions to prox on
+function prox!(y::Tuple, f::PrecomposedSlicedSeparableSum, x::Tuple, gamma)
+    v = zero(eltype(x[1]))
+    counter = 1
+    for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type
+        for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type
+            sliced_x = Tuple(slice_var(x_var, idx) for (x_var, idx) in zip(x, idx_group))
+            sliced_y = Tuple(slice_var(y_var, idx) for (y_var, idx) in zip(y, idx_group))
+            res = hcat_op * sliced_x
+            prox_res, g = prox(fun, res, μ.*gamma)
+            prox_res .-= res
+            prox_res ./= μ
+            mul!(sliced_y, adjoint(hcat_op), prox_res)
+            for i in eachindex(sliced_x)
+                sliced_y[i] .+= sliced_x[i]
+            end
+            v += g
+            counter += 1
+        end
+    end
+    return v
+end
+
+component_types(::Type{PrecomposedSlicedSeparableSum{S, T, U, V, N}}) where {S, T, U, V, N} = Tuple(A.parameters[1] for A in fieldtypes(S))
+
+@generated is_proximable(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false)
+@generated is_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false)
+@generated is_set_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false)
+@generated is_singleton_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false)
+@generated is_cone_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false)
+@generated is_affine_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? :(true) : :(false)
+@generated is_smooth(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false)
+@generated is_generalized_quadratic(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false)
+@generated is_strongly_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ?
:(true) : :(false) + +function prox_naive(f::PrecomposedSlicedSeparableSum, x, gamma) + fy = 0 + y = similar.(x) + for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type + for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type + sliced_x = Tuple(slice_var(x_var, idx) for (x_var, idx) in zip(x, idx_group)) + sliced_y = Tuple(slice_var(y_var, idx) for (y_var, idx) in zip(y, idx_group)) + res = hcat_op * sliced_x + prox_res, _fy = prox_naive(fun, res, μ.*gamma) + prox_res = (prox_res .- res) ./ μ + mul!(sliced_y, adjoint(hcat_op), prox_res) + fy += _fy + for i in eachindex(sliced_x) + sliced_y[i] .+= sliced_x[i] + end + end + end + return y, fy +end diff --git a/test/runtests.jl b/test/runtests.jl index 9abe42bb..b233553d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -150,6 +150,7 @@ end include("test_regularize.jl") include("test_separableSum.jl") include("test_slicedSeparableSum.jl") + include("test_precomposedSlicedSeparableSum.jl") include("test_sum.jl") end diff --git a/test/test_precomposedSlicedSeparableSum.jl b/test/test_precomposedSlicedSeparableSum.jl new file mode 100644 index 00000000..a1739384 --- /dev/null +++ b/test/test_precomposedSlicedSeparableSum.jl @@ -0,0 +1,74 @@ +# x = (randn(10), randn(10)) +# norm(x[1], 1) + norm(A2[1:5, 1:5] * x[2][1:5], 2) + norm(A2[6:10, 6:10] * x[2][6:10], 2)^2 + +@testset "PrecomposedSlicedSeparableSum" begin + +fs = (NormL1(), NormL2(), SqrNormL2()) + +A1 = (Diagonal(ones(10)), nothing) +F = qr(randn(5, 5)) +A2 = (nothing, Matrix(F.Q)) +F = qr(randn(5, 5)) +A3 = (nothing, Matrix(F.Q)) +mu = rand(5) +A3[2] .*= reshape(mu, 5, 1) +ops = (A1, A2, A3) + +idxs = ((Colon(), nothing), (nothing, 1:5), (nothing, 6:10)) +μs = (1.0, 1.0, mu) + +AAc2 = A2[2] * A2[2]' +@test AAc2 ≈ I +AAc3 = A3[2] * A3[2]' +@test AAc3 ≈ Diagonal(mu) .^ 2 + +import Base: *, adjoint +# mock HCAT from AbstractOperators + +function Base.:*(A::Tuple, x::Tuple) + codomain_size = size(A[findfirst(a -> !(a isa Nothing), A)], 1) + out = zeros(codomain_size) + for i in eachindex(A) + if A[i] isa Nothing + continue + end + out .+= A[i] * x[i] + end + return out +end +function Base.:*(A::Tuple, x::AbstractArray) # adjoint operation + domain_size = size(A[findfirst(a -> !(a isa Nothing), A)], 2) + return mul!(Tuple(zeros(domain_size) for _ in eachindex(A)), A, x) +end +function LinearAlgebra.mul!(out::Tuple, A::Tuple, x::AbstractArray) + for i in eachindex(A) + if A[i] isa Nothing + out[i] .= 0 + else + out[i] .+= A[i]' * x + end + end + return out +end +Base.adjoint(A::Tuple) = A + + +f = PrecomposedSlicedSeparableSum(fs, idxs, ops, μs) +x = (randn(10), rand(10)) +y = (zeros(10), zeros(10)) +fy = prox!(y, f, x, 1.0) +yn, fyn = ProximalOperators.prox_naive(f, x, 1.0) +y1, fy1 = prox(NormL1(), x[1], 1.0) +y2, fy2 = prox(Precompose(NormL2(), A2[2], 1), x[2][1:5], 1.0) +y3, fy3 = prox(Precompose(SqrNormL2(), A3[2], mu), x[2][6:10], 1.0) + +@show f(y), fy +@show norm(y) +@test abs(fyn-fy)<1e-11 +@test norm(yn[1]-y[1])+norm(yn[2]-y[2])<1e-11 +@test abs((fy1+fy2+fy3)-fy)<1e-11 +@test norm(y[1] - y1) < 1e-11 +@test norm(y[2][1:5] - y2) < 1e-11 +@test norm(y[2][6:10] - y3) < 1e-11 + +end From e048f082193f4697b36dcb23be8d92a07aca9478 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 22 Apr 2025 20:41:23 +0200 Subject: [PATCH 3/7] add reshapeInput function wrapper --- src/ProximalOperators.jl | 1 + src/calculus/reshapeInput.jl | 59 ++++++++++++++++++++++++++++++++++++ 2 files 
changed, 60 insertions(+)
 create mode 100644 src/calculus/reshapeInput.jl

diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl
index af5bf604..c90d35d8 100644
--- a/src/ProximalOperators.jl
+++ b/src/ProximalOperators.jl
@@ -92,6 +92,7 @@ include("calculus/regularize.jl")
 include("calculus/separableSum.jl")
 include("calculus/slicedSeparableSum.jl")
 include("calculus/precomposedSlicedSeparableSum.jl")
+include("calculus/reshapeInput.jl")
 include("calculus/sqrDistL2.jl")
 include("calculus/tilt.jl")
 include("calculus/translate.jl")
diff --git a/src/calculus/reshapeInput.jl b/src/calculus/reshapeInput.jl
new file mode 100644
index 00000000..6b32ab3a
--- /dev/null
+++ b/src/calculus/reshapeInput.jl
@@ -0,0 +1,59 @@
+# wrap a function to reshape the input
+
+export ReshapeInput
+
+"""
+    ReshapeInput(f, expected_shape)
+
+Wrap a function to reshape the input.
+It is useful when the function `f` expects its input in a specific shape, but you want to pass it an array of a different shape (with the same number of elements).
+
+```julia
+julia> f = ReshapeInput(IndBallRank(5), (10, 10))
+ReshapeInput(IndBallRank{Int64}(5), (10, 10))
+julia> f(rand(100))
+Inf
+```
+"""
+struct ReshapeInput{F, S}
+    f::F
+    expected_shape::S
+end
+
+function (f::ReshapeInput)(x)
+    # Check if the input x has the expected shape
+    if size(x) != f.expected_shape
+        # Reshape the input to the expected shape
+        x = reshape(x, f.expected_shape)
+    end
+    return f.f(x)
+end
+
+function prox!(y, f::ReshapeInput, x, gamma)
+    # Check if the input x has the expected shape
+    if size(x) != f.expected_shape
+        # Reshape the input to the expected shape
+        x = reshape(x, f.expected_shape)
+        y = reshape(y, f.expected_shape)
+    end
+    return prox!(y, f.f, x, gamma)
+end
+
+function gradient!(y, f::ReshapeInput, x)
+    # Check if the input x has the expected shape
+    if size(x) != f.expected_shape
+        # Reshape the input to the expected shape
+        x = reshape(x, f.expected_shape)
+        y = reshape(y, f.expected_shape)
+    end
+    return gradient!(y, f.f, x)
+end
+
+function prox_naive(f::ReshapeInput, x, gamma)
+    # Check if the input x has the expected shape
+    if size(x) != f.expected_shape
+        # Reshape the input to the expected shape
+        x = reshape(x, f.expected_shape)
+    end
+    return prox_naive(f.f, x, gamma)
+end

From 0c47c266a4dfcc98e4c26cfd8596640386099d87 Mon Sep 17 00:00:00 2001
From: Tamas Hakkel
Date: Tue, 22 Apr 2025 21:15:39 +0200
Subject: [PATCH 4/7] add RecursiveArrayToolsExt to support ArrayPartitions

---
 Project.toml                  |  7 +++++++
 ext/RecursiveArrayToolsExt.jl | 19 +++++++++++++++++++
 2 files changed, 26 insertions(+)
 create mode 100644 ext/RecursiveArrayToolsExt.jl

diff --git a/Project.toml b/Project.toml
index 8de132c4..977a3de4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,11 +11,18 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
 TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e"
 
+[weakdeps]
+RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
+
+[extensions]
+RecursiveArrayToolsExt = "RecursiveArrayTools"
+
 [compat]
 IterativeSolvers = "0.8 - 0.9"
 LinearAlgebra = "1.4"
 OSQP = "0.3 - 0.8"
 ProximalCore = "0.2"
+RecursiveArrayTools = "2, 3"
 SparseArrays = "1.4"
 SuiteSparse = "1.4"
 TSVD = "0.3 - 0.4"
diff --git a/ext/RecursiveArrayToolsExt.jl b/ext/RecursiveArrayToolsExt.jl
new file mode 100644
index 00000000..0a92839e
--- /dev/null
+++ b/ext/RecursiveArrayToolsExt.jl
@@ -0,0 +1,19 @@
+module RecursiveArrayToolsExt
+using RecursiveArrayTools
+using ProximalOperators
+import ProximalCore: prox, prox!, gradient, gradient!
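+
+# Forward `ArrayPartition` arguments to the underlying tuple of arrays (`x.x`),
+# so the tuple-based methods defined for these calculus objects can be reused as-is.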
+ +(f::PrecomposedSlicedSeparableSum)(x::ArrayPartition) = f(x.x) +prox!(y::ArrayPartition, f::PrecomposedSlicedSeparableSum, x::ArrayPartition, gamma) = prox!(y.x, f, x.x, gamma) + +(g::SeparableSum)(xs::ArrayPartition) = g(xs.x) +prox!(ys::ArrayPartition, g::SeparableSum, xs::ArrayPartition, gamma::Number) = prox!(ys.x, g, xs.x, gamma) +prox!(ys::ArrayPartition, g::SeparableSum, xs::ArrayPartition, gammas::Tuple) = prox!(ys.x, g, xs.x, gammas) +function prox(g::SeparableSum, xs::ArrayPartition, gamma=1) + y, fy = prox(g, xs.x, gamma) + return ArrayPartition(y), fy +end +gradient!(grads::ArrayPartition, g::SeparableSum, xs::ArrayPartition) = gradient!(grads.x, g, xs.x) +gradient(g::SeparableSum, xs::ArrayPartition) = gradient(g, xs.x) + +end # module RecursiveArrayToolsExt From 1ab0be88bcf39e3744cd60de72a28c83c98d079f Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Wed, 7 May 2025 13:53:33 +0200 Subject: [PATCH 5/7] minor corrections --- src/calculus/precompose.jl | 10 ++++++++-- src/functions/indBallL1.jl | 1 - src/functions/normL1.jl | 1 + src/functions/sqrNormL2.jl | 13 +++++++------ src/utilities/linops.jl | 3 ++- 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/calculus/precompose.jl b/src/calculus/precompose.jl index a104ceb2..0fdd25c9 100644 --- a/src/calculus/precompose.jl +++ b/src/calculus/precompose.jl @@ -57,7 +57,10 @@ function (g::Precompose)(x) end function gradient!(y, g::Precompose, x) - res = g.L*x .+ g.b + res = g.L*x + if g.b != 0 + res .+= g.b + end gradres = similar(res) v = gradient!(gradres, g.f, res) mul!(y, adjoint(g.L), gradres) @@ -75,7 +78,10 @@ function prox!(y, g::Precompose, x, gamma) # prox_f(x) = prox_h(x + b) - b # Then one can apply the above mentioned result to g(x) = f(Lx). # - res = g.L*x .+ g.b + res = g.L * x + if g.b != 0 + res .+= g.b + end proxres = similar(res) v = prox!(proxres, g.f, res, g.mu.*gamma) proxres .-= res diff --git a/src/functions/indBallL1.jl b/src/functions/indBallL1.jl index b37a0608..37a2851e 100644 --- a/src/functions/indBallL1.jl +++ b/src/functions/indBallL1.jl @@ -24,7 +24,6 @@ end is_convex(f::Type{<:IndBallL1}) = true is_set_indicator(f::Type{<:IndBallL1}) = true -is_proximable(f::Type{<:IndBallL1}) = false IndBallL1(r::R=1.0) where R = IndBallL1{R}(r) diff --git a/src/functions/normL1.jl b/src/functions/normL1.jl index f77abbba..ccee4797 100644 --- a/src/functions/normL1.jl +++ b/src/functions/normL1.jl @@ -28,6 +28,7 @@ struct NormL1{T} end end +is_proximable(f::Type{<:NormL1}) = true is_separable(f::Type{<:NormL1}) = true is_convex(f::Type{<:NormL1}) = true is_positively_homogeneous(f::Type{<:NormL1}) = true diff --git a/src/functions/sqrNormL2.jl b/src/functions/sqrNormL2.jl index 0069c593..9d705b72 100644 --- a/src/functions/sqrNormL2.jl +++ b/src/functions/sqrNormL2.jl @@ -25,15 +25,16 @@ struct SqrNormL2{T,SC} end end -is_convex(f::Type{<:SqrNormL2}) = true -is_smooth(f::Type{<:SqrNormL2}) = true -is_separable(f::Type{<:SqrNormL2}) = true -is_generalized_quadratic(f::Type{<:SqrNormL2}) = true -is_strongly_convex(f::Type{SqrNormL2{T,SC}}) where {T,SC} = SC +is_proximable(::Type{<:SqrNormL2}) = true +is_convex(::Type{<:SqrNormL2}) = true +is_smooth(::Type{<:SqrNormL2}) = true +is_separable(::Type{<:SqrNormL2}) = true +is_generalized_quadratic(::Type{<:SqrNormL2}) = true +is_strongly_convex(::Type{SqrNormL2{T,SC}}) where {T,SC} = SC SqrNormL2(lambda::T=1) where T = SqrNormL2{T,all(lambda .> 0)}(lambda) -function (f::SqrNormL2{S})(x) where {S <: Real} +function (f::SqrNormL2{<:Real})(x) return 
f.lambda / real(eltype(x))(2) * norm(x)^2 end diff --git a/src/utilities/linops.jl b/src/utilities/linops.jl index e12f6401..170bd368 100644 --- a/src/utilities/linops.jl +++ b/src/utilities/linops.jl @@ -9,7 +9,8 @@ infer_shape_of_y(Op, ::AbstractVector) = (size(Op, 1), ) infer_shape_of_y(Op, x::AbstractMatrix) = (size(Op, 1), size(x, 2)) function (*)(Op::LinOp, x) - y = zeros(promote_type(eltype(Op), eltype(x)), infer_shape_of_y(Op, x)) + y = similar(x, promote_type(eltype(Op), eltype(x)), infer_shape_of_y(Op, x)) + y .= 0 mul!(y, Op, x) end From 12f1c9e4a8f8f30c8364ffd1a7992135c602da81 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Thu, 6 Nov 2025 20:37:05 +0100 Subject: [PATCH 6/7] Defer loading OSQP for faster startup time --- src/functions/indPolyhedralOSQP.jl | 36 ++++++++++++++---------------- test/runtests.jl | 3 ++- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/functions/indPolyhedralOSQP.jl b/src/functions/indPolyhedralOSQP.jl index 6fb67d7f..99963497 100644 --- a/src/functions/indPolyhedralOSQP.jl +++ b/src/functions/indPolyhedralOSQP.jl @@ -1,24 +1,25 @@ # IndPolyhedral: OSQP implementation -using OSQP - -struct IndPolyhedralOSQP{R} <: IndPolyhedral +struct IndPolyhedralOSQP{R,M} <: IndPolyhedral l::AbstractVector{R} A::AbstractMatrix{R} u::AbstractVector{R} - mod::OSQP.Model - function IndPolyhedralOSQP{R}( + mod::M + function IndPolyhedralOSQP( l::AbstractVector{R}, A::AbstractMatrix{R}, u::AbstractVector{R} ) where R m, n = size(A) - mod = OSQP.Model() - if !all(l .<= u) - error("function is improper (are some bounds inverted?)") + mod = Base.invokelatest(Base.require(@__MODULE__, :OSQP)) do OSQP + mod = OSQP.Model() + if !all(l .<= u) + error("function is improper (are some bounds inverted?)") + end + OSQP.setup!(mod; P=SparseMatrixCSC{R}(I, n, n), l=l, A=sparse(A), u=u, verbose=false, + eps_abs=eps(R), eps_rel=eps(R), + eps_prim_inf=eps(R), eps_dual_inf=eps(R)) + mod end - OSQP.setup!(mod; P=SparseMatrixCSC{R}(I, n, n), l=l, A=sparse(A), u=u, verbose=false, - eps_abs=eps(R), eps_rel=eps(R), - eps_prim_inf=eps(R), eps_dual_inf=eps(R)) - new(l, A, u, mod) + new{R,typeof(mod)}(l, A, u, mod) end end @@ -28,11 +29,6 @@ is_proximable(::Type{<:IndPolyhedralOSQP}) = false # constructors -IndPolyhedralOSQP( - l::AbstractVector{R}, A::AbstractMatrix{R}, u::AbstractVector{R} -) where R = - IndPolyhedralOSQP{R}(l, A, u) - IndPolyhedralOSQP( l::AbstractVector{R}, A::AbstractMatrix{R}, u::AbstractVector{R}, xmin::AbstractVector{R}, xmax::AbstractVector{R} @@ -65,8 +61,10 @@ end function prox!(y, f::IndPolyhedralOSQP, x, gamma) R = eltype(x) - OSQP.update!(f.mod; q=-x) - results = OSQP.solve!(f.mod) + results = Base.invokelatest(Base.require(@__MODULE__, :OSQP)) do OSQP + OSQP.update!(f.mod; q=-x) + OSQP.solve!(f.mod) + end y .= results.x return R(0) end diff --git a/test/runtests.jl b/test/runtests.jl index b233553d..c9f1bf63 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -113,7 +113,8 @@ function predicates_test(f) end @testset "Aqua" begin - Aqua.test_all(ProximalOperators; ambiguities=false) + Aqua.test_all(ProximalOperators; ambiguities=false, stale_deps=false, persistent_tasks=false) + Aqua.test_stale_deps(ProximalOperators, ignore=[:OSQP]) end @testset "Utilities" begin From e9ecd4f788f38bee885aa1707c22e148cacf45af Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 11 Nov 2025 19:43:26 +0100 Subject: [PATCH 7/7] Update CI --- .github/workflows/benchmark.yml | 9 +++++++- .github/workflows/docs.yml | 9 +++++++- .github/workflows/test.yml | 40 
++++++++++++++++++++------------- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4f8ba54f..6437348d 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -5,11 +5,18 @@ jobs: benchmark: name: Benchmark runs-on: ubuntu-latest + + # needed to allow julia-actions/cache to delete old caches that it has created + permissions: + actions: write + contents: read + steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v1 with: version: '1' + - uses: julia-actions/cache@v2 - run: git fetch origin '+refs/heads/master:refs/remotes/origin/master' - run: git branch master origin/master - run: | diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5f828111..64facc21 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -11,11 +11,18 @@ jobs: build: name: Documentation runs-on: ubuntu-latest + + # needed to allow julia-actions/cache to delete old caches that it has created + permissions: + actions: write + contents: read + steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@latest with: version: '1' + - uses: julia-actions/cache@v2 - name: Install dependencies run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - name: Build and deploy diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f891d64e..57b5b0f7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,30 +6,40 @@ on: - master pull_request: workflow_dispatch: + jobs: build: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} + name: Julia ${{ matrix.julia-version }} - ${{ matrix.os }} - ${{ matrix.julia-arch }} runs-on: ${{ matrix.os }} + strategy: fail-fast: false matrix: - version: - - '1' - - '1.6' - os: - - ubuntu-latest - - macOS-latest - - windows-latest - arch: - - x64 + julia-version: ['lts', '1'] + julia-arch: [x64] + os: [ubuntu-latest, windows-latest, macOS-latest] + + # needed to allow julia-actions/cache to delete old caches that it has created + permissions: + actions: write + contents: read + steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@latest with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} + version: ${{ matrix.julia-version }} + arch: ${{ matrix.julia-arch }} + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest - - uses: julia-actions/julia-uploadcodecov@latest + - uses: julia-actions/julia-runtest@v1 + with: + coverage: true + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v5 + # Upload coverage only from one job (Linux, Julia latest version) + if: matrix.os == 'ubuntu-latest' && matrix.julia-version == '1' + with: + files: lcov.info env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
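
The series above adds two user-facing pieces: the `ReshapeInput` wrapper and `ArrayPartition` support through the `RecursiveArrayToolsExt` extension. The sketch below shows how they are expected to fit together once the patches are applied; the specific functions, sizes, and values are illustrative assumptions and not prescribed by the patches.

```julia
using LinearAlgebra
using ProximalOperators
using RecursiveArrayTools

# ReshapeInput: evaluate a matrix-shaped function on a flattened vector.
# IndBallRank works on matrices, so the wrapper reshapes the length-16 vector to 4x4 first.
f = ReshapeInput(IndBallRank(2), (4, 4))
x = vec(randn(4, 2) * randn(2, 4))  # a rank-2 matrix stored as a length-16 vector
f(x)                                # expected 0.0: the reshaped input lies in the rank-2 ball
y, fy = prox(f, randn(16), 1.0)     # y keeps the caller's flat shape

# ArrayPartition support: each block is routed to the corresponding term of a SeparableSum.
g = SeparableSum(NormL1(), SqrNormL2())
xp = ArrayPartition(randn(5), randn(3))
g(xp)                               # NormL1 of the first block plus SqrNormL2 of the second
yp = ArrayPartition(zeros(5), zeros(3))
gy = prox!(yp, g, xp, 0.5)          # proximal point written into yp; gy is g evaluated there
```

The `ReshapeInput` prox relies on `reshape` returning a view of the original array, so writes into the reshaped `y` land in the caller's flat array without extra copies.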