Commit 9067e90

Update gradient interface, support AbstractDifferentiation (#90)
1 parent 4d07b28 · commit 9067e90

37 files changed: 382 additions, 617 deletions

Project.toml

Lines changed: 3 additions & 3 deletions
@@ -1,16 +1,16 @@
 name = "ProximalAlgorithms"
 uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
-version = "0.5.5"
+version = "0.6.0"
 
 [deps]
+AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+AbstractDifferentiation = "0.6"
 LinearAlgebra = "1.2"
 Printf = "1.2"
 ProximalCore = "0.1"
-Zygote = "0.6"
 julia = "1.2"

README.md

Lines changed: 15 additions & 8 deletions
@@ -11,14 +11,21 @@ A Julia package for non-smooth optimization algorithms.
 This package provides algorithms for the minimization of objective functions
 that include non-smooth terms, such as constraints or non-differentiable penalties.
 Implemented algorithms include:
-* (Fast) Proximal gradient methods
-* Douglas-Rachford splitting
-* Three-term splitting
-* Primal-dual splitting algorithms
-* Newton-type methods
-
-This package works well in combination with [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15),
-which contains a wide range of functions that can be used to express cost terms.
+- (Fast) Proximal gradient methods
+- Douglas-Rachford splitting
+- Three-term splitting
+- Primal-dual splitting algorithms
+- Newton-type methods
+
+Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms.
+
+Algorithms rely on:
+- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation
+  (but you can easily bring your own gradients)
+- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc,
+  to handle non-differentiable terms
+  (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl)
+  for an extensive collection of functions).
 
 ## Documentation
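
The two options listed above look as follows in practice: wrap a plain function in `AutoDifferentiable` with a backend such as `ZygoteBackend()`, or implement `ProximalAlgorithms.value_and_gradient_closure` for your own type. A minimal sketch (the `HalfSqrNorm` type is purely illustrative):

```julia
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

# Option 1: let AbstractDifferentiation compute the gradient.
f_ad = ProximalAlgorithms.AutoDifferentiable(x -> sum(abs2, x) / 2, ZygoteBackend())

# Option 2: bring your own gradient for a custom function type.
struct HalfSqrNorm end
(::HalfSqrNorm)(x) = sum(abs2, x) / 2
function ProximalAlgorithms.value_and_gradient_closure(::HalfSqrNorm, x)
    # value of ||x||^2 / 2, and a closure returning its gradient, which is x
    return sum(abs2, x) / 2, () -> x
end
```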

benchmark/benchmarks.jl

Lines changed: 19 additions & 3 deletions
@@ -8,6 +8,22 @@ using FileIO
 
 const SUITE = BenchmarkGroup()
 
+function ProximalAlgorithms.value_and_gradient_closure(f::ProximalOperators.LeastSquaresDirect, x)
+    res = f.A*x - f.b
+    norm(res)^2, () -> f.A'*res
+end
+
+struct SquaredDistance{Tb}
+    b::Tb
+end
+
+(f::SquaredDistance)(x) = norm(x - f.b)^2
+
+function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x)
+    diff = x - f.b
+    norm(diff)^2, () -> diff
+end
+
 for (benchmark_name, file_name) in [
     ("Lasso tiny", joinpath(@__DIR__, "data", "lasso_tiny.jld2")),
     ("Lasso small", joinpath(@__DIR__, "data", "lasso_small.jld2")),
@@ -42,21 +58,21 @@ for (benchmark_name, file_name) in [
     SUITE[k]["ZeroFPR"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin
         solver = ProximalAlgorithms.ZeroFPR(tol=1e-6)
         x0 = zeros($T, size($A, 2))
-        f = Translate(SqrNormL2(), -$b)
+        f = SquaredDistance($b)
         g = NormL1($lam)
     end
 
     SUITE[k]["PANOC"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin
         solver = ProximalAlgorithms.PANOC(tol=1e-6)
         x0 = zeros($T, size($A, 2))
-        f = Translate(SqrNormL2(), -$b)
+        f = SquaredDistance($b)
         g = NormL1($lam)
     end
 
     SUITE[k]["PANOCplus"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin
         solver = ProximalAlgorithms.PANOCplus(tol=1e-6)
         x0 = zeros($T, size($A, 2))
-        f = Translate(SqrNormL2(), -$b)
+        f = SquaredDistance($b)
         g = NormL1($lam)
     end
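
A rough sketch of how algorithms consume this interface (illustrative only, not the package's internal code): the value is computed up front, while the gradient is materialized only when the closure is called.

```julia
# Illustrative gradient step using the SquaredDistance type defined above.
f = SquaredDistance([1.0, 2.0, 3.0])
x = zeros(3)
fx, cl = ProximalAlgorithms.value_and_gradient_closure(f, x)
x_next = x - 0.5 * cl()   # cl() evaluates the gradient closure defined above
```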

docs/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 [deps]
+AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
@@ -7,6 +8,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
 ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
 ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 Documenter = "1"

docs/src/examples/sparse_linear_regression.jl

Lines changed: 5 additions & 2 deletions
@@ -51,10 +51,13 @@ end
 
 mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2
 
+using Zygote
+using AbstractDifferentiation: ZygoteBackend
 using ProximalAlgorithms
 
-training_loss = ProximalAlgorithms.ZygoteFunction(
-    wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input))
+training_loss = ProximalAlgorithms.AutoDifferentiable(
+    wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
+    ZygoteBackend()
 )
 
 # As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
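
The wrapped loss remains a plain callable, and its gradient comes from the returned closure; a small sketch, where `wb0` stands for a hypothetical parameter vector of the right size:

```julia
# wb0 is hypothetical: any weight/bias vector matching the model above.
loss = training_loss(wb0)
loss_again, cl = ProximalAlgorithms.value_and_gradient_closure(training_loss, wb0)
grad = cl()   # gradient of the training loss at wb0, via the Zygote backend
```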

docs/src/guide/custom_objectives.jl

Lines changed: 27 additions & 17 deletions
@@ -12,29 +12,32 @@
 #
 # Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
 #
-# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref),
-# relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation.
-# Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken,
-# and everything will work out of the box.
+# To compute gradients, algorithms use [`ProximalAlgorithms.value_and_gradient_closure`](@ref):
+# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation
+# with any of its supported backends, when functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref),
+# as the examples below show.
 #
-# If however one would like to provide their own gradient implementation (e.g. for efficiency reasons),
-# they can simply implement a method for [`ProximalCore.gradient!`](@ref).
+# If however you would like to provide your own gradient implementation (e.g. for efficiency reasons),
+# you can simply implement a method for [`ProximalAlgorithms.value_and_gradient_closure`](@ref) on your own function type.
 #
 # ```@docs
 # ProximalCore.prox
 # ProximalCore.prox!
-# ProximalCore.gradient
-# ProximalCore.gradient!
+# ProximalAlgorithms.value_and_gradient_closure
+# ProximalAlgorithms.AutoDifferentiable
 # ```
 #
 # ## Example: constrained Rosenbrock
 #
 # Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is
 
+using Zygote
+using AbstractDifferentiation: ZygoteBackend
 using ProximalAlgorithms
 
-rosenbrock2D = ProximalAlgorithms.ZygoteFunction(
-    x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2
+rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
+    x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
+    ZygoteBackend()
 )
 
 # To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
@@ -82,17 +85,23 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co
 
 mutable struct Counting{T}
     f::T
+    eval_count::Int
     gradient_count::Int
     prox_count::Int
 end
 
-Counting(f::T) where T = Counting{T}(f, 0, 0)
+Counting(f::T) where T = Counting{T}(f, 0, 0, 0)
 
-# Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there:
+# Now we only need to intercept any call to `value_and_gradient_closure` and `prox!` and increase counters there:
 
-function ProximalCore.gradient!(y, f::Counting, x)
-    f.gradient_count += 1
-    return ProximalCore.gradient!(y, f.f, x)
+function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x)
+    f.eval_count += 1
+    fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x)
+    function counting_pullback()
+        f.gradient_count += 1
+        return pb()
+    end
+    return fx, counting_pullback
 end
 
 function ProximalCore.prox!(y, f::Counting, x, gamma)
@@ -109,5 +118,6 @@ solution, iterations = panoc(x0=-ones(2), f=f, g=g)
 
 # and check how many operations where actually performed:
 
-println(f.gradient_count)
-println(g.prox_count)
+println("function evals: $(f.eval_count)")
+println("gradient evals: $(f.gradient_count)")
+println("    prox evals: $(g.prox_count)")

docs/src/guide/getting_started.jl

Lines changed: 9 additions & 4 deletions
@@ -20,7 +20,7 @@
 # The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
 #
 # To evaluate these first-order primitives, in ProximalAlgorithms:
-# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [Zygote](https://github.com/FluxML/Zygote.jl)).
+# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
 # * ``\operatorname{prox}_{f_i}`` relies on the intereface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
 # Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
 #
@@ -51,11 +51,14 @@
 # which we will solve using the fast proximal gradient method (also known as fast forward-backward splitting):
 
 using LinearAlgebra
+using Zygote
+using AbstractDifferentiation: ZygoteBackend
 using ProximalOperators
 using ProximalAlgorithms
 
-quadratic_cost = ProximalAlgorithms.ZygoteFunction(
-    x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x)
+quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
+    x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
+    ZygoteBackend()
 )
 box_indicator = ProximalOperators.IndBox(0, 1)
 
@@ -69,8 +72,10 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit=1000, tol=1e-5, verbose=true)
 solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator)
 
 # We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:
+# for this, we just evaluate the closure `cl` returned as second output of [`value_and_gradient_closure`](@ref).
 
--ProximalAlgorithms.gradient(quadratic_cost, solution)[1]
+v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
+-cl()
 
 # Or by plotting the solution against the cost function and constraint:
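
For the quadratic above, the closure can also be checked against the analytic gradient; a quick consistency sketch, using the same `quadratic_cost` and `solution` as in the guide:

```julia
# For f(x) = x'Qx/2 + q'x with symmetric Q, the gradient is Q*x + q.
Q = [3.4 1.2; 1.2 4.5]
q = [-2.3, 9.9]
v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
cl() ≈ Q * solution + q   # expected to hold up to floating-point error
```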

docs/src/index.md

Lines changed: 13 additions & 16 deletions
@@ -5,16 +5,21 @@ A Julia package for non-smooth optimization algorithms. [Link to GitHub reposito
 This package provides algorithms for the minimization of objective functions
 that include non-smooth terms, such as constraints or non-differentiable penalties.
 Implemented algorithms include:
-* (Fast) Proximal gradient methods
-* Douglas-Rachford splitting
-* Three-term splitting
-* Primal-dual splitting algorithms
-* Newton-type methods
+- (Fast) Proximal gradient methods
+- Douglas-Rachford splitting
+- Three-term splitting
+- Primal-dual splitting algorithms
+- Newton-type methods
 
 Check out [this section](@ref problems_algorithms) for an overview of the available algorithms.
 
-This package works well in combination with [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15),
-which contains a wide range of functions that can be used to express cost terms.
+Algorithms rely on:
+- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation
+  (but you can easily bring your own gradients)
+- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc,
+  to handle non-differentiable terms
+  (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl)
+  for an extensive collection of functions).
 
 !!! note
 
@@ -23,20 +28,11 @@ which contains a wide range of functions that can be used to express cost terms.
 
 ## Installation
 
-Install the latest stable release with
-
 ```julia
 julia> ]
 pkg> add ProximalAlgorithms
 ```
 
-To install the development version instead (`master` branch), do
-
-```julia
-julia> ]
-pkg> add ProximalAlgorithms#master
-```
-
 ## Citing
 
 If you use any of the algorithms from ProximalAlgorithms in your research, you are kindly asked to cite the relevant bibliography.
@@ -45,3 +41,4 @@ Please check [this section of the manual](@ref problems_algorithms) for algorith
 ## Contributing
 
 Contributions are welcome in the form of [issue notifications](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl/issues) or [pull requests](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl/pulls). When contributing new algorithms, we highly recommend looking at already implemented ones to get inspiration on how to structure the code.
+

justfile

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+julia:
+    julia --project=.
+
+instantiate:
+    julia --project=. -e 'using Pkg; Pkg.instantiate()'
+
+test:
+    julia --project=. -e 'using Pkg; Pkg.test()'
+
+format:
+    julia --project=. -e 'using JuliaFormatter: format; format(".")'
+
+docs:
+    julia --project=./docs docs/make.jl
+
+benchmark:
+    julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
+    julia --project=benchmark benchmark/runbenchmarks.jl
+

src/ProximalAlgorithms.jl

Lines changed: 36 additions & 2 deletions
@@ -1,14 +1,48 @@
 module ProximalAlgorithms
 
+using AbstractDifferentiation
 using ProximalCore
-using ProximalCore: prox, prox!, gradient, gradient!
+using ProximalCore: prox, prox!
 
 const RealOrComplex{R} = Union{R,Complex{R}}
 const Maybe{T} = Union{T,Nothing}
 
+"""
+    AutoDifferentiable(f, backend)
+
+Callable struct wrapping function `f` to be auto-differentiated using `backend`.
+
+When called, it evaluates the same as `f`, while [`ProximalAlgorithms.value_and_gradient_closure`](@ref)
+is implemented using `backend` for automatic differentiation.
+The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
+"""
+struct AutoDifferentiable{F, B}
+    f::F
+    backend::B
+end
+
+(f::AutoDifferentiable)(x) = f.f(x)
+
+"""
+    value_and_gradient_closure(f, x)
+
+Return a tuple containing the value of `f` at `x`, and a closure `cl`.
+
+Function `cl`, once called, yields the gradient of `f` at `x`.
+"""
+value_and_gradient_closure
+
+function value_and_gradient_closure(f::AutoDifferentiable, x)
+    fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
+    return fx, () -> pb(one(fx))[1]
+end
+
+function value_and_gradient_closure(f::ProximalCore.Zero, x)
+    f(x), () -> zero(x)
+end
+
 # various utilities
 
-include("utilities/ad.jl")
 include("utilities/fb_tools.jl")
 include("utilities/iteration_tools.jl")
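
A sketch of the contract established here, assuming Zygote is available as the backend: the first output is the value of `f` at `x`, the second is a closure that yields the gradient only when called; `ProximalCore.Zero` gets a zero gradient without any AD machinery.

```julia
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

h = ProximalAlgorithms.AutoDifferentiable(x -> sum(exp, x), ZygoteBackend())
x = [0.0, 1.0]
hx, cl = ProximalAlgorithms.value_and_gradient_closure(h, x)
hx ≈ sum(exp, x)   # value, computed eagerly
cl() ≈ exp.(x)     # gradient of sum(exp, x), computed lazily by the closure
```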
