Skip to content

Commit d4bf817

Browse files
authored
Use autodiff API based on ADTypes instead of symbols (#1195)
1 parent af776cf commit d4bf817

File tree

16 files changed

+92
-86
lines changed

16 files changed

+92
-86
lines changed

Project.toml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ uuid = "429524aa-4258-5aef-a3af-852621145aeb"
33
version = "1.14.0"
44

55
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
67
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
78
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
8-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11-
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1211
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
1312
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1413
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
@@ -50,22 +49,19 @@ Test = "<0.0.1, 1.6"
5049
julia = "1.10"
5150

5251
[extras]
53-
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
5452
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5553
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
5654
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
55+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5756
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
58-
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
5957
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
6058
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
61-
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
6259
OptimTestProblems = "cec144fc-5a64-5bc6-99fb-dde8f63e154c"
63-
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
6460
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6561
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
6662
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
6763
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
6864
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6965

7066
[targets]
71-
test = ["Test", "Aqua", "Distributions", "ExplicitImports", "JET", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "LineSearches", "NLSolversBase", "PositiveFactorizations", "ReverseDiff", "ADTypes"]
67+
test = ["Test", "Aqua", "Distributions", "ExplicitImports", "ForwardDiff", "JET", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "ReverseDiff"]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
34
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
45
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

docs/src/examples/ipnewton_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ using Test #src
7878
@test Optim.converged(res) #src
7979
@test Optim.minimum(res) 0.25 #src
8080

81-
# Like the rest of Optim, you can also use `autodiff=:forward` and just pass in
81+
# Like the rest of Optim, you can also use `autodiff=ADTypes.AutoForwardDiff()` and just pass in
8282
# `fun`.
8383

8484
# If we only want to set lower bounds, use `ux = fill(Inf, 2)`

docs/src/examples/maxlikenlm.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
using Optim, NLSolversBase
2323
using LinearAlgebra: diag
2424
using ForwardDiff
25+
using ADTypes: AutoForwardDiff
2526

2627
#md # !!! tip
2728
#md # Add Optim with the following command at the Julia command prompt:
@@ -152,7 +153,7 @@ end
152153
func = TwiceDifferentiable(
153154
vars -> Log_Likelihood(x, y, vars[1:nvar], vars[nvar+1]),
154155
ones(nvar + 1);
155-
autodiff = :forward,
156+
autodiff = AutoForwardDiff(),
156157
);
157158

158159
# The above statement accepts 4 inputs: the x matrix, the dependent
@@ -163,7 +164,7 @@ func = TwiceDifferentiable(
163164
# the error variance.
164165
#
165166
# The `ones(nvar+1)` are the starting values for the parameters and
166-
# the `autodiff=:forward` command performs forward mode automatic
167+
# the `autodiff=ADTypes.AutoForwardDiff()` command performs forward mode automatic
167168
# differentiation.
168169
#
169170
# The actual optimization of the likelihood function is accomplished

docs/src/user/gradientsandhessians.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ Automatic differentiation techniques are a middle ground between finite differen
1616

1717
Reverse-mode automatic differentiation can be seen as an automatic implementation of the adjoint method mentioned above, and requires a runtime comparable to only one evaluation of ``f``. It is however considerably more complex to implement, requiring to record the execution of the program to then run it backwards, and incurs a larger overhead.
1818

19-
Forward-mode automatic differentiation is supported through the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package by providing the `autodiff=:forward` keyword to `optimize`.
20-
More generic automatic differentiation is supported thanks to [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), by setting `autodiff` to any compatible backend object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
21-
For instance, the user can choose `autodiff=AutoReverseDiff()`, `autodiff=AutoEnzyme()`, `autodiff=AutoMooncake()` or `autodiff=AutoZygote()` for a reverse-mode gradient computation, which is generally faster than forward mode on large inputs.
22-
Each of these choices requires loading the corresponding package beforehand.
19+
Generic automatic differentiation is supported thanks to [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), by setting `autodiff` to any compatible backend object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
20+
For instance, forward-mode automatic differentiation is supported through the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package by providing the `autodiff=ADTypes.AutoForwardDiff()` keyword to `optimize`.
21+
Additionally, the user can choose `autodiff=AutoReverseDiff()`, `autodiff=AutoEnzyme()`, `autodiff=AutoMooncake()` or `autodiff=AutoZygote()` for a reverse-mode gradient computation, which is generally faster than forward mode on large inputs.
22+
Each of these choices requires loading the `ADTypes` package and the corresponding automatic differentiation package (e.g., `ForwardDiff` or `ReverseDiff`) beforehand.
2323

2424
## Example
2525

@@ -66,14 +66,16 @@ julia> Optim.minimizer(optimize(f, initial_x, BFGS()))
6666
```
6767
Still looks good. Returning to automatic differentiation, let us try both solvers using this
6868
method. We enable [forward mode](https://github.com/JuliaDiff/ForwardDiff.jl) automatic
69-
differentiation by using the `autodiff = :forward` keyword.
69+
differentiation by using the `autodiff = AutoForwardDiff()` keyword.
7070
```jlcon
71-
julia> Optim.minimizer(optimize(f, initial_x, BFGS(); autodiff = :forward))
71+
julia> using ADTypes: AutoForwardDiff
72+
73+
julia> Optim.minimizer(optimize(f, initial_x, BFGS(); autodiff = AutoForwardDiff()))
7274
2-element Array{Float64,1}:
7375
1.0
7476
1.0
7577
76-
julia> Optim.minimizer(optimize(f, initial_x, Newton(); autodiff = :forward))
78+
julia> Optim.minimizer(optimize(f, initial_x, Newton(); autodiff = AutoForwardDiff()))
7779
2-element Array{Float64,1}:
7880
1.0
7981
1.0

docs/src/user/minimization.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ If we pass `f` alone, Optim will construct an approximate gradient for us using
2626
```jl
2727
optimize(f, x0, LBFGS())
2828
```
29-
For better performance and greater precision, you can pass your own gradient function. If your objective is written in all Julia code with no special calls to external (that is non-Julia) libraries, you can also use automatic differentiation, by using the `autodiff` keyword and setting it to `:forward`:
29+
For better performance and greater precision, you can pass your own gradient function. If your objective is written in all Julia code with no special calls to external (that is non-Julia) libraries, you can also use automatic differentiation, by using the `autodiff` keyword and setting it to `AutoForwardDiff()`:
3030
```julia
31-
optimize(f, x0, LBFGS(); autodiff = :forward)
31+
using ADTypes: AutoForwardDiff
32+
optimize(f, x0, LBFGS(); autodiff = AutoForwardDiff())
3233
```
3334

3435
For the Rosenbrock example, the analytical gradient can be shown to be:

ext/OptimMOIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ function MOI.optimize!(model::Optimizer{T}) where {T}
333333
inplace = true,
334334
)
335335
else
336-
d = Optim.promote_objtype(method, initial_x, :finite, true, f, g!, h!)
336+
d = Optim.promote_objtype(method, initial_x, Optim.DEFAULT_AD_TYPE, true, f, g!, h!)
337337
options = Optim.Options(; Optim.default_options(method)..., options...)
338338
if nl_constrained || has_bounds
339339
if nl_constrained

src/Optim.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ using NLSolversBase:
5050
# var for NelderMead
5151
import StatsBase: var
5252

53+
import ADTypes
54+
5355
using LinearAlgebra:
5456
LinearAlgebra,
5557
Diagonal,

src/multivariate/optimize/interface.jl

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ fallback_method(f) = NelderMead()
44
fallback_method(f, g!) = LBFGS()
55
fallback_method(f, g!, h!) = Newton()
66

7+
# By default, use central finite difference method
8+
const DEFAULT_AD_TYPE = ADTypes.AutoFiniteDiff(; fdtype = Val(:central))
9+
710
function fallback_method(f::InplaceObjective)
811
if !(f.fdf isa Nothing)
912
if !(f.hv isa Nothing)
@@ -36,48 +39,48 @@ fallback_method(d::OnceDifferentiable) = LBFGS()
3639
fallback_method(d::TwiceDifferentiable) = Newton()
3740

3841
# promote the objective (tuple of callables or an AbstractObjective) according to method requirement
39-
promote_objtype(method, initial_x, autodiff, inplace::Bool, args...) =
42+
promote_objtype(method, initial_x, autodiff::ADTypes.AbstractADType, inplace::Bool, args...) =
4043
error("No default objective type for $method and $args.")
4144
# actual promotions, notice that (args...) captures FirstOrderOptimizer and NonDifferentiable, etc
42-
promote_objtype(method::ZerothOrderOptimizer, x, autodiff, inplace::Bool, args...) =
45+
promote_objtype(method::ZerothOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, args...) =
4346
NonDifferentiable(args..., x, real(zero(eltype(x))))
44-
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, f) =
47+
promote_objtype(method::FirstOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f) =
4548
OnceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
46-
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, args...) =
49+
promote_objtype(method::FirstOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, args...) =
4750
OnceDifferentiable(args..., x, real(zero(eltype(x))); inplace = inplace)
48-
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, f, g, h) =
51+
promote_objtype(method::FirstOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f, g, h) =
4952
OnceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace)
50-
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f) =
53+
promote_objtype(method::SecondOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f) =
5154
TwiceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
5255
promote_objtype(
5356
method::SecondOrderOptimizer,
5457
x,
55-
autodiff,
58+
autodiff::ADTypes.AbstractADType,
5659
inplace::Bool,
5760
f::NotInplaceObjective,
5861
) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
5962
promote_objtype(
6063
method::SecondOrderOptimizer,
6164
x,
62-
autodiff,
65+
autodiff::ADTypes.AbstractADType,
6366
inplace::Bool,
6467
f::InplaceObjective,
6568
) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
6669
promote_objtype(
6770
method::SecondOrderOptimizer,
6871
x,
69-
autodiff,
72+
autodiff::ADTypes.AbstractADType,
7073
inplace::Bool,
7174
f::NLSolversBase.InPlaceObjectiveFGHv,
7275
) = TwiceDifferentiableHV(f, x)
7376
promote_objtype(
7477
method::SecondOrderOptimizer,
7578
x,
76-
autodiff,
79+
autodiff::ADTypes.AbstractADType,
7780
inplace::Bool,
7881
f::NLSolversBase.InPlaceObjectiveFG_Hv,
7982
) = TwiceDifferentiableHV(f, x)
80-
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f, g) =
83+
promote_objtype(method::SecondOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f, g) =
8184
TwiceDifferentiable(
8285
f,
8386
g,
@@ -86,48 +89,48 @@ promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f, g)
8689
inplace = inplace,
8790
autodiff = autodiff,
8891
)
89-
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f, g, h) =
92+
promote_objtype(method::SecondOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f, g, h) =
9093
TwiceDifferentiable(f, g, h, x, real(zero(eltype(x))); inplace = inplace)
9194
# no-op
9295
promote_objtype(
9396
method::ZerothOrderOptimizer,
9497
x,
95-
autodiff,
98+
autodiff::ADTypes.AbstractADType,
9699
inplace::Bool,
97100
nd::NonDifferentiable,
98101
) = nd
99102
promote_objtype(
100103
method::ZerothOrderOptimizer,
101104
x,
102-
autodiff,
105+
autodiff::ADTypes.AbstractADType,
103106
inplace::Bool,
104107
od::OnceDifferentiable,
105108
) = od
106109
promote_objtype(
107110
method::FirstOrderOptimizer,
108111
x,
109-
autodiff,
112+
autodiff::ADTypes.AbstractADType,
110113
inplace::Bool,
111114
od::OnceDifferentiable,
112115
) = od
113116
promote_objtype(
114117
method::ZerothOrderOptimizer,
115118
x,
116-
autodiff,
119+
autodiff::ADTypes.AbstractADType,
117120
inplace::Bool,
118121
td::TwiceDifferentiable,
119122
) = td
120123
promote_objtype(
121124
method::FirstOrderOptimizer,
122125
x,
123-
autodiff,
126+
autodiff::ADTypes.AbstractADType,
124127
inplace::Bool,
125128
td::TwiceDifferentiable,
126129
) = td
127130
promote_objtype(
128131
method::SecondOrderOptimizer,
129132
x,
130-
autodiff,
133+
autodiff::ADTypes.AbstractADType,
131134
inplace::Bool,
132135
td::TwiceDifferentiable,
133136
) = td
@@ -136,8 +139,8 @@ promote_objtype(
136139
function optimize(
137140
f,
138141
initial_x::AbstractArray;
139-
inplace = true,
140-
autodiff = :finite,
142+
inplace::Bool = true,
143+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
141144
)
142145
method = fallback_method(f)
143146
d = promote_objtype(method, initial_x, autodiff, inplace, f)
@@ -149,8 +152,8 @@ function optimize(
149152
f,
150153
g,
151154
initial_x::AbstractArray;
152-
autodiff = :finite,
153-
inplace = true,
155+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
156+
inplace::Bool = true,
154157
)
155158

156159
method = fallback_method(f, g)
@@ -165,8 +168,8 @@ function optimize(
165168
g,
166169
h,
167170
initial_x::AbstractArray;
168-
inplace = true,
169-
autodiff = :finite
171+
inplace::Bool = true,
172+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
170173
)
171174
method = fallback_method(f, g, h)
172175
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
@@ -188,8 +191,8 @@ function optimize(
188191
f,
189192
initial_x::AbstractArray,
190193
options::Options;
191-
inplace = true,
192-
autodiff = :finite,
194+
inplace::Bool = true,
195+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
193196
)
194197
method = fallback_method(f)
195198
d = promote_objtype(method, initial_x, autodiff, inplace, f)
@@ -200,8 +203,8 @@ function optimize(
200203
g,
201204
initial_x::AbstractArray,
202205
options::Options;
203-
inplace = true,
204-
autodiff = :finite,
206+
inplace::Bool = true,
207+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
205208
)
206209

207210
method = fallback_method(f, g)
@@ -214,8 +217,8 @@ function optimize(
214217
h,
215218
initial_x::AbstractArray{T},
216219
options::Options;
217-
inplace = true,
218-
autodiff = :finite,
220+
inplace::Bool = true,
221+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
219222
) where {T}
220223
method = fallback_method(f, g, h)
221224
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
@@ -229,8 +232,8 @@ function optimize(
229232
initial_x::AbstractArray,
230233
method::AbstractOptimizer,
231234
options::Options = Options(; default_options(method)...);
232-
inplace = true,
233-
autodiff = :finite,
235+
inplace::Bool = true,
236+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
234237
)
235238
d = promote_objtype(method, initial_x, autodiff, inplace, f)
236239
optimize(d, initial_x, method, options)
@@ -241,8 +244,8 @@ function optimize(
241244
initial_x::AbstractArray,
242245
method::AbstractOptimizer,
243246
options::Options = Options(; default_options(method)...);
244-
inplace = true,
245-
autodiff = :finite,
247+
inplace::Bool = true,
248+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
246249
)
247250

248251
d = promote_objtype(method, initial_x, autodiff, inplace, f)
@@ -254,8 +257,8 @@ function optimize(
254257
initial_x::AbstractArray,
255258
method::AbstractOptimizer,
256259
options::Options = Options(; default_options(method)...);
257-
inplace = true,
258-
autodiff = :finite,
260+
inplace::Bool = true,
261+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
259262
)
260263
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
261264

@@ -268,8 +271,8 @@ function optimize(
268271
initial_x::AbstractArray,
269272
method::AbstractOptimizer,
270273
options::Options = Options(; default_options(method)...);
271-
inplace = true,
272-
autodiff = :finite,
274+
inplace::Bool = true,
275+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
273276

274277
)
275278
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
@@ -282,8 +285,8 @@ function optimize(
282285
initial_x::AbstractArray,
283286
method::SecondOrderOptimizer,
284287
options::Options = Options(; default_options(method)...);
285-
inplace = true,
286-
autodiff = :finite,
288+
inplace::Bool = true,
289+
autodiff::ADTypes.AbstractADType = DEFAULT_AD_TYPE,
287290
) where {D<:Union{NonDifferentiable,OnceDifferentiable}}
288291
d = promote_objtype(method, initial_x, autodiff, inplace, d)
289292
optimize(d, initial_x, method, options)

0 commit comments

Comments
 (0)