#
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, algorithms use [`value_and_gradient`](@ref):
# this relies on [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation
# with any of its supported backends, when functions are wrapped in [`AutoDifferentiable`](@ref),
# as the examples below show.
#
# If, however, you would like to provide your own gradient implementation (e.g. for efficiency reasons),
# you can simply implement a method for [`value_and_gradient`](@ref) on your own function type,
# as sketched right below.
#
# ```@docs
# ProximalCore.prox
# ProximalCore.prox!
# ProximalAlgorithms.value_and_gradient
# ProximalAlgorithms.AutoDifferentiable
# ```
#
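# As a quick illustration of the latter point, here is a minimal sketch using a hypothetical
# `SquaredNorm` type (not part of the package), whose gradient is known in closed form:

using LinearAlgebra
using ProximalAlgorithms

struct SquaredNorm end  # hypothetical example type, f(x) = ‖x‖² / 2

(::SquaredNorm)(x) = norm(x)^2 / 2

## hand-written value and gradient, bypassing automatic differentiation:
## the gradient of ‖x‖² / 2 is x itself
function ProximalAlgorithms.value_and_gradient(f::SquaredNorm, x)
    return f(x), copy(x)
end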
# Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is

using Zygote
using DifferentiationInterface: AutoZygote
using ProximalAlgorithms

rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
    x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
    AutoZygote(),
)
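# As a quick check, the wrapped function can be evaluated and differentiated through the
# unified interface (this is the same call that algorithms make internally):

fx, grad = ProximalAlgorithms.value_and_gradient(rosenbrock2D, zeros(2))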
# To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
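# (The definition is omitted from this excerpt; following the `prox!` convention of writing
# the result into `y` and returning the function value at `y`, a minimal version could look
# as follows, with `IndUnitBall` as an illustrative name.)

using LinearAlgebra
using ProximalCore

struct IndUnitBall end  # illustrative indicator of the unit norm ball

(::IndUnitBall)(x) = norm(x) > 1 ? eltype(x)(Inf) : eltype(x)(0)

function ProximalCore.prox!(y, ::IndUnitBall, x, gamma)
    if norm(x) > 1
        y .= x ./ norm(x)  # project onto the unit ball
    else
        y .= x
    end
    return zero(eltype(x))  # the indicator equals zero at the projection
end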
# To keep track of how many times the objective, its gradient, and the proximal mapping are
# evaluated during the execution of an algorithm, we can wrap functions in a `Counting` type:

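# (The definition of `Counting` falls in the portion omitted from this excerpt; judging from
# the constructor and the methods below, it is a mutable wrapper along these lines.)

mutable struct Counting{T}
    f::T                 # wrapped function
    eval_count::Int      # number of evaluations of f
    gradient_count::Int  # number of gradient evaluations
    prox_count::Int      # number of prox evaluations
end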
Counting(f::T) where {T} = Counting{T}(f, 0, 0, 0)

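# Calling the wrapper evaluates the underlying function while bumping the evaluation counter: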
function (f::Counting)(x)
    f.eval_count += 1
    return f.f(x)
end

# Now we only need to intercept any call to [`value_and_gradient`](@ref) and [`prox!`](@ref) and increase counters there:

function ProximalAlgorithms.value_and_gradient(f::Counting, x)
    f.eval_count += 1
    f.gradient_count += 1
    return ProximalAlgorithms.value_and_gradient(f.f, x)
end

function ProximalCore.prox!(y, f::Counting, x, gamma)
    f.prox_count += 1
    return ProximalCore.prox!(y, f.f, x, gamma)
end
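# With these methods in place, we can wrap the objective terms, run an algorithm, and read
# off the counters. A hypothetical usage, reusing the `IndUnitBall` sketch from above
# (solver type and options may of course differ):

f = Counting(rosenbrock2D)
g = Counting(IndUnitBall())

panoc = ProximalAlgorithms.PANOC()
solution, iterations = panoc(x0 = zeros(2), f = f, g = g)

println("objective evaluations: ", f.eval_count)
println("gradient evaluations:  ", f.gradient_count)
println("prox evaluations:      ", g.prox_count)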