
Commit 27f8a96

gdalle and lostella authored
Switch from AbstractDifferentiation to DifferentiationInterface (#93)
Following our discussion per email, this PR proposes a switch from AbstractDifferentiation.jl to DifferentiationInterface.jl, which is becoming the new standard in the ecosystem.

- [x] Modify `Project.toml` files and imports
- [x] Replace `SomethingBackend()` with `AutoSomething()`
- [x] Replace `value_and_gradient_closure` with `value_and_gradient` (unclear how performance is affected)
- [x] Update documentation and README
- [ ] Add [preparation mechanism](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/operators/#Preparation): available on another branch, but not sure we want it, because if the function contains value-dependent control flow, preparation is not appropriate

---------

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>
1 parent b3e667e commit 27f8a96

31 files changed (+126 / -138 lines)
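For downstream users, the change amounts to picking the backend via an ADTypes-style constructor and receiving the gradient directly instead of a closure. A minimal before/after sketch, assuming Zygote as the backend; the cost function and variable names here are illustrative, not taken from the diff:

```julia
using Zygote
using DifferentiationInterface: AutoZygote  # previously: using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

# illustrative differentiable cost, wrapped together with the chosen backend
cost = ProximalAlgorithms.AutoDifferentiable(x -> sum(abs2, x) / 2, AutoZygote())

# Old API: fx, cl = ProximalAlgorithms.value_and_gradient_closure(cost, [1.0, 2.0]); grad = cl()
# New API: value and gradient are returned together, eagerly
fx, grad = ProximalAlgorithms.value_and_gradient(cost, [1.0, 2.0])
```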

Project.toml

Lines changed: 5 additions & 3 deletions
@@ -1,15 +1,17 @@
 name = "ProximalAlgorithms"
 uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
-version = "0.6.0"
+version = "0.7.0"
 
 [deps]
-AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
 
 [compat]
-AbstractDifferentiation = "0.6"
+ADTypes = "1.5.3"
+DifferentiationInterface = "0.5.8"
 LinearAlgebra = "1.2"
 Printf = "1.2"
 ProximalCore = "0.1"

README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ Implemented algorithms include:
 Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms.
 
 Algorithms rely on:
-- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients)
+- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients)
 - the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).
 
 ## Documentation

benchmark/benchmarks.jl

Lines changed: 4 additions & 4 deletions
@@ -8,12 +8,12 @@ using FileIO
 
 const SUITE = BenchmarkGroup()
 
-function ProximalAlgorithms.value_and_gradient_closure(
+function ProximalAlgorithms.value_and_gradient(
     f::ProximalOperators.LeastSquaresDirect,
     x,
 )
     res = f.A * x - f.b
-    norm(res)^2 / 2, () -> f.A' * res
+    norm(res)^2 / 2, f.A' * res
 end
 
 struct SquaredDistance{Tb}
@@ -22,9 +22,9 @@ end
 
 (f::SquaredDistance)(x) = norm(x - f.b)^2 / 2
 
-function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x)
+function ProximalAlgorithms.value_and_gradient(f::SquaredDistance, x)
     diff = x - f.b
-    norm(diff)^2 / 2, () -> diff
+    norm(diff)^2 / 2, diff
 end
 
 for (benchmark_name, file_name) in [
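As this file illustrates, bringing your own gradient now means returning the gradient itself rather than a thunk that computes it. A minimal sketch with a hypothetical quadratic type (not part of the diff), assuming a symmetric matrix `Q`:

```julia
using LinearAlgebra
using ProximalAlgorithms

struct Quadratic{TQ,Tq}
    Q::TQ   # assumed symmetric
    q::Tq
end

(f::Quadratic)(x) = dot(x, f.Q * x) / 2 + dot(f.q, x)

# Hand-written gradient, bypassing AD entirely: return (value, gradient) eagerly.
function ProximalAlgorithms.value_and_gradient(f::Quadratic, x)
    Qx = f.Q * x
    return dot(x, Qx) / 2 + dot(f.q, x), Qx + f.q
end
```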

docs/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [deps]
-AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"

docs/src/examples/sparse_linear_regression.jl

Lines changed: 2 additions & 2 deletions
@@ -53,12 +53,12 @@ end
 mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2
 
 using Zygote
-using AbstractDifferentiation: ZygoteBackend
+using DifferentiationInterface: AutoZygote
 using ProximalAlgorithms
 
 training_loss = ProximalAlgorithms.AutoDifferentiable(
     wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
-    ZygoteBackend(),
+    AutoZygote(),
 )
 
 # As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):

docs/src/guide/custom_objectives.jl

Lines changed: 15 additions & 14 deletions
@@ -12,18 +12,18 @@
 #
 # Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
 #
-# To compute gradients, algorithms use [`value_and_gradient_closure`](@ref):
-# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation
+# To compute gradients, algorithms use [`value_and_gradient`](@ref):
+# this relies on [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl), for automatic differentiation
 # with any of its supported backends, when functions are wrapped in [`AutoDifferentiable`](@ref),
 # as the examples below show.
 #
 # If however you would like to provide your own gradient implementation (e.g. for efficiency reasons),
-# you can simply implement a method for [`value_and_gradient_closure`](@ref) on your own function type.
+# you can simply implement a method for [`value_and_gradient`](@ref) on your own function type.
 #
 # ```@docs
 # ProximalCore.prox
 # ProximalCore.prox!
-# ProximalAlgorithms.value_and_gradient_closure
+# ProximalAlgorithms.value_and_gradient
 # ProximalAlgorithms.AutoDifferentiable
 # ```
 #
@@ -32,12 +32,12 @@
 # Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is
 
 using Zygote
-using AbstractDifferentiation: ZygoteBackend
+using DifferentiationInterface: AutoZygote
 using ProximalAlgorithms
 
 rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
     x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
-    ZygoteBackend(),
+    AutoZygote(),
 )
 
 # To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
@@ -105,16 +105,17 @@ end
 
 Counting(f::T) where {T} = Counting{T}(f, 0, 0, 0)
 
-# Now we only need to intercept any call to [`value_and_gradient_closure`](@ref) and [`prox!`](@ref) and increase counters there:
+function (f::Counting)(x)
+    f.eval_count += 1
+    return f.f(x)
+end
 
-function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x)
+# Now we only need to intercept any call to [`value_and_gradient`](@ref) and [`prox!`](@ref) and increase counters there:
+
+function ProximalAlgorithms.value_and_gradient(f::Counting, x)
     f.eval_count += 1
-    fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x)
-    function counting_pullback()
-        f.gradient_count += 1
-        return pb()
-    end
-    return fx, counting_pullback
+    f.gradient_count += 1
+    return ProximalAlgorithms.value_and_gradient(f.f, x)
 end
 
 function ProximalCore.prox!(y, f::Counting, x, gamma)
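A hedged usage sketch of the counting wrapper after this change: wrap the differentiable term, run a solver, then read the counters. Here `ball_indicator` stands in for the unit-ball indicator defined earlier in that guide (its actual name is not shown in this hunk), and the solver constructor and options are illustrative:

```julia
counted_cost = Counting(rosenbrock2D)

ffb = ProximalAlgorithms.FastForwardBackward()
solution, iterations = ffb(x0 = zeros(2), f = counted_cost, g = ball_indicator)

# how many times the objective and its gradient were queried during the run
counted_cost.eval_count, counted_cost.gradient_count
```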

docs/src/guide/getting_started.jl

Lines changed: 5 additions & 6 deletions
@@ -20,7 +20,7 @@
 # The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
 #
 # To evaluate these first-order primitives, in ProximalAlgorithms:
-# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
+# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) and all of its backends).
 # * ``\operatorname{prox}_{f_i}`` relies on the interface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
 # Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
 #
@@ -52,13 +52,13 @@
 
 using LinearAlgebra
 using Zygote
-using AbstractDifferentiation: ZygoteBackend
+using DifferentiationInterface: AutoZygote
 using ProximalOperators
 using ProximalAlgorithms
 
 quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
     x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
-    ZygoteBackend(),
+    AutoZygote(),
 )
 box_indicator = ProximalOperators.IndBox(0, 1)
 
@@ -72,10 +72,9 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit = 1000, tol = 1e-5, verbose =
 solution, iterations = ffb(x0 = ones(2), f = quadratic_cost, g = box_indicator)
 
 # We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:
-# for this, we just evaluate the closure `cl` returned as second output of [`value_and_gradient_closure`](@ref).
+# for this, we just evaluate the second output of [`value_and_gradient`](@ref).
 
-v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
--cl()
+last(ProximalAlgorithms.value_and_gradient(quadratic_cost, solution))
 
 # Or by plotting the solution against the cost function and constraint:
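In code, the optimality check described in that hunk amounts to inspecting the sign pattern of the negative gradient at the returned solution; a sketch using the names from the hunk (no actual output values are claimed here):

```julia
_, grad = ProximalAlgorithms.value_and_gradient(quadratic_cost, solution)

# At a solution of min f(x) subject to 0 .<= x .<= 1, each component of -grad should be
# <= 0 where the solution sits at the lower bound, >= 0 at the upper bound, and ≈ 0 in the interior.
-grad
```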

docs/src/index.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ Implemented algorithms include:
 Check out [this section](@ref problems_algorithms) for an overview of the available algorithms.
 
 Algorithms rely on:
-- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients),
+- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients),
 - the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).
 
 !!! note

src/ProximalAlgorithms.jl

Lines changed: 12 additions & 14 deletions
@@ -1,6 +1,7 @@
 module ProximalAlgorithms
 
-using AbstractDifferentiation
+using ADTypes: ADTypes
+using DifferentiationInterface: DifferentiationInterface
 using ProximalCore
 using ProximalCore: prox, prox!
 
@@ -12,33 +13,30 @@ const Maybe{T} = Union{T,Nothing}
 
 Callable struct wrapping function `f` to be auto-differentiated using `backend`.
 
-When called, it evaluates the same as `f`, while [`value_and_gradient_closure`](@ref)
+When called, it evaluates the same as `f`, while its gradient
 is implemented using `backend` for automatic differentiation.
-The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
+The backend can be any of those supported by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).
 """
-struct AutoDifferentiable{F,B}
+struct AutoDifferentiable{F,B<:ADTypes.AbstractADType}
     f::F
     backend::B
 end
 
 (f::AutoDifferentiable)(x) = f.f(x)
 
 """
-    value_and_gradient_closure(f, x)
+    value_and_gradient(f, x)
 
-Return a tuple containing the value of `f` at `x`, and a closure `cl`.
-
-Function `cl`, once called, yields the gradient of `f` at `x`.
+Return a tuple containing the value of `f` at `x` and the gradient of `f` at `x`.
 """
-value_and_gradient_closure
+value_and_gradient
 
-function value_and_gradient_closure(f::AutoDifferentiable, x)
-    fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
-    return fx, () -> pb(one(fx))[1]
+function value_and_gradient(f::AutoDifferentiable, x)
+    return DifferentiationInterface.value_and_gradient(f.f, f.backend, x)
end
 
-function value_and_gradient_closure(f::ProximalCore.Zero, x)
-    f(x), () -> zero(x)
+function value_and_gradient(f::ProximalCore.Zero, x)
+    return f(x), zero(x)
 end
 
 # various utilities
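One consequence of the new `B<:ADTypes.AbstractADType` constraint is that only ADTypes backend objects can be stored in `AutoDifferentiable`; passing anything else now fails at construction with a `MethodError`. The `ProximalCore.Zero` method keeps an exact, AD-free gradient; a minimal sketch, where the expected result in the comment follows directly from the definition above:

```julia
using ProximalCore
using ProximalAlgorithms

# Zero needs no AD backend: the value is 0 and the gradient is a zero vector of matching shape.
ProximalAlgorithms.value_and_gradient(ProximalCore.Zero(), [1.0, 2.0])  # == (0.0, [0.0, 0.0])
```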

src/algorithms/davis_yin.jl

Lines changed: 3 additions & 4 deletions
@@ -56,8 +56,7 @@ end
 function Base.iterate(iter::DavisYinIteration)
     z = copy(iter.x0)
     xg, = prox(iter.g, z, iter.gamma)
-    f_xg, cl = value_and_gradient_closure(iter.f, xg)
-    grad_f_xg = cl()
+    f_xg, grad_f_xg = value_and_gradient(iter.f, xg)
     z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
     xh, = prox(iter.h, z_half, iter.gamma)
     res = xh - xg
@@ -68,8 +67,8 @@ end
 
 function Base.iterate(iter::DavisYinIteration, state::DavisYinState)
     prox!(state.xg, iter.g, state.z, iter.gamma)
-    f_xg, cl = value_and_gradient_closure(iter.f, state.xg)
-    state.grad_f_xg .= cl()
+    f_xg, grad_f_xg = value_and_gradient(iter.f, state.xg)
+    state.grad_f_xg .= grad_f_xg
     state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg
     prox!(state.xh, iter.h, state.z_half, iter.gamma)
     state.res .= state.xh .- state.xg
