Skip to content

Commit 9a21814

Browse files
Merge pull request #1209 from SciML/enzymeadjoint
Add EnzymeAdjoint
2 parents 627f2a2 + 70ba8e9 commit 9a21814

File tree

5 files changed

+146
-5
lines changed

5 files changed

+146
-5
lines changed

docs/src/manual/differential_equation_sensitivities.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ QuadratureAdjoint
200200
ReverseDiffAdjoint
201201
TrackerAdjoint
202202
ZygoteAdjoint
203+
EnzymeAdjoint
204+
MooncakeAdjoint
203205
ForwardLSS
204206
AdjointLSS
205207
NILSS

src/SciMLSensitivity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityF
9696

9797
export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint,
9898
TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint, MooncakeAdjoint,
99-
ForwardSensitivity, ForwardDiffSensitivity,
99+
EnzymeAdjoint, ForwardSensitivity, ForwardDiffSensitivity,
100100
ForwardDiffOverAdjoint,
101101
SteadyStateAdjoint,
102102
ForwardLSS, AdjointLSS, NILSS, NILSAS

src/concrete_solve.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,88 @@ function DiffEqBase._concrete_solve_adjoint(
12601260
p)
12611261
end
12621262

1263+
function DiffEqBase._concrete_solve_adjoint(
1264+
prob::Union{SciMLBase.AbstractDiscreteProblem,
1265+
SciMLBase.AbstractODEProblem,
1266+
SciMLBase.AbstractDAEProblem,
1267+
SciMLBase.AbstractDDEProblem,
1268+
SciMLBase.AbstractSDEProblem,
1269+
SciMLBase.AbstractSDDEProblem,
1270+
SciMLBase.AbstractRODEProblem
1271+
},
1272+
alg, sensealg::EnzymeAdjoint,
1273+
u0, p, originator::SciMLBase.ADOriginator,
1274+
args...; kwargs...)
1275+
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
1276+
du0 = Enzyme.make_zero(u0)
1277+
dp = Enzyme.make_zero(p)
1278+
mode = sensealg.mode
1279+
1280+
f = (u0, p) -> solve(prob, alg, args...; u0 = u0, p = p,
1281+
sensealg = SensitivityADPassThrough(),
1282+
kwargs_filtered...)
1283+
1284+
splitmode = if mode isa Enzyme.ForwardMode
1285+
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
1286+
elseif mode === nothing || mode isa Enzyme.ReverseMode
1287+
Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
1288+
end
1289+
1290+
forward, reverse = Enzyme.autodiff_thunk(splitmode, Enzyme.Const{typeof(f)}, Enzyme.Duplicated, Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
1291+
tape, result, shadow_result = forward(Enzyme.Const(f), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))
1292+
1293+
function enzyme_sensitivity_backpass(Δ)
1294+
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
1295+
if originator isa SciMLBase.TrackerOriginator ||
1296+
originator isa SciMLBase.ReverseDiffOriginator
1297+
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
1298+
ntuple(_ -> NoTangent(), length(args))...)
1299+
else
1300+
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
1301+
ntuple(_ -> NoTangent(), length(args))...)
1302+
end
1303+
end
1304+
sol, enzyme_sensitivity_backpass
1305+
end
1306+
1307+
# NOTE: This is needed to prevent a method ambiguity error
1308+
function DiffEqBase._concrete_solve_adjoint(
1309+
prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
1310+
u0, p, originator::SciMLBase.ADOriginator,
1311+
args...; kwargs...)
1312+
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
1313+
1314+
du0 = make_zero(u0)
1315+
dp = make_zero(p)
1316+
mode = sensealg.mode
1317+
1318+
f = (u0, p) -> solve(prob, alg, args...; u0 = u0, p = p,
1319+
sensealg = SensitivityADPassThrough(),
1320+
kwargs_filtered...)
1321+
1322+
splitmode = if mode isa Forward
1323+
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
1324+
elseif mode === nothing || mode === Reverse
1325+
ReverseSplitWithPrimal
1326+
end
1327+
1328+
forward, reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated, Duplicated{typeof(u0)}, Duplicated{typeof(p)})
1329+
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))
1330+
1331+
function enzyme_sensitivity_backpass(Δ)
1332+
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
1333+
if originator isa SciMLBase.TrackerOriginator ||
1334+
originator isa SciMLBase.ReverseDiffOriginator
1335+
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
1336+
ntuple(_ -> NoTangent(), length(args))...)
1337+
else
1338+
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
1339+
ntuple(_ -> NoTangent(), length(args))...)
1340+
end
1341+
end
1342+
sol, enzyme_sensitivity_backpass
1343+
end
1344+
12631345
const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
12641346
`Enzyme` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
12651347
Either choose a different adjoint method like `GaussAdjoint`,

src/sensitivity_algorithms.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ MooncakeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
615615
An implementation of discrete adjoint sensitivity analysis
616616
using the Mooncake.jl direct differentiation.
617617
618+
!!! warn
619+
This is currently experimental and supports only explicit solvers. It will
620+
support all solvers in the future.
621+
618622
## Constructor
619623
620624
```julia
@@ -656,6 +660,9 @@ An implementation of discrete adjoint sensitivity analysis
656660
using the Zygote.jl source-to-source AD directly on the differential equation
657661
solver.
658662
663+
!!! warn
664+
This is only supports SimpleDiffEq.jl solvers due to limitations of Enzyme.
665+
659666
## Constructor
660667
661668
```julia
@@ -668,6 +675,38 @@ Currently fails on almost every solver.
668675
"""
669676
struct ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end
670677

678+
"""
679+
EnzymeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing}
680+
681+
An implementation of discrete adjoint sensitivity analysis
682+
using the Enzyme.jl source-to-source AD directly on the differential equation
683+
solver.
684+
685+
!!! warn
686+
This is currently experimental and supports only explicit solvers. It will
687+
support all solvers in the future.
688+
689+
## Constructor
690+
691+
```julia
692+
EnzymeAdjoint(mode = nothing)
693+
```
694+
695+
## Arugments
696+
697+
* `mode::M` determines the autodiff mode (forward or reverse). It can be:
698+
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
699+
+ `nothing` to choose the best mode automatically
700+
701+
## SciMLProblem Support
702+
703+
Currently fails on almost every solver.
704+
"""
705+
struct EnzymeAdjoint{M <: Union{Nothing,Enzyme.EnzymeCore.Mode}} <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
706+
mode::M
707+
EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
708+
end
709+
671710
"""
672711
```julia
673712
ForwardLSS{CS, AD, FDT, RType, gType} <: AbstractShadowingSensitivityAlgorithm{CS, AD, FDT}

test/concrete_solve_derivatives.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SciMLSensitivity, OrdinaryDiffEq, Zygote
22
using Test, ForwardDiff
3-
import Tracker, ReverseDiff, ChainRulesCore, Mooncake
3+
import Tracker, ReverseDiff, ChainRulesCore, Mooncake, Enzyme
44

55
function fiip(du, u, p, t)
66
du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
@@ -79,6 +79,13 @@ du06, dp6 = Zygote.gradient(
7979
sensealg = MooncakeAdjoint())),
8080
u0,
8181
p)
82+
@test_broken du08, dp8 = Zygote.gradient(
83+
(u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
84+
abstol = 1e-14, reltol = 1e-14,
85+
saveat = 0.1,
86+
sensealg = EnzymeAdjoint())),
87+
u0,
88+
p)
8289

8390
@test ū0du01 rtol=1e-12
8491
@test ū0 == du02
@@ -87,13 +94,15 @@ du06, dp6 = Zygote.gradient(
8794
#@test ū0 ≈ du05 rtol=1e-12
8895
@test ū0du06 rtol=1e-12
8996
@test_broken ū0du07 rtol=1e-12
97+
@test_broken ū0du08 rtol=1e-12
9098
@test adjdp1' rtol=1e-12
9199
@test adj == dp2'
92100
@test adjdp3' rtol=1e-12
93101
@test adjdp4' rtol=1e-12
94102
#@test adj ≈ dp5' rtol=1e-12
95103
@test adjdp6' rtol=1e-12
96104
@test_broken adjdp7' rtol=1e-12
105+
@test_broken adjdp8' rtol=1e-12
97106

98107
###
99108
### Direct from prob
@@ -322,7 +331,14 @@ du06, dp6 = Zygote.gradient(
322331
sensealg = MooncakeAdjoint())),
323332
u0,
324333
p)
325-
du08, dp8 = Zygote.gradient(
334+
@test_broken du08, dp8 = Zygote.gradient(
335+
(u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
336+
abstol = 1e-14, reltol = 1e-14,
337+
saveat = 0.1,
338+
sensealg = EnzymeAdjoint())),
339+
u0,
340+
p)
341+
du09, dp9 = Zygote.gradient(
326342
(u0, p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
327343
abstol = 1e-14, reltol = 1e-14,
328344
saveat = 0.1,
@@ -337,15 +353,17 @@ du08, dp8 = Zygote.gradient(
337353
#@test ū0 ≈ du05 rtol=1e-12
338354
@test ū0du06 rtol=1e-12
339355
@test_broken ū0du07 rtol=1e-12
340-
@test ū0du08 rtol=1e-12
356+
@test_broken ū0du08 rtol=1e-12
357+
@test ū0du09 rtol=1e-12
341358
@test adjdp1' rtol=1e-12
342359
@test adjdp2' rtol=1e-12
343360
@test adjdp3' rtol=1e-12
344361
@test adjdp4' rtol=1e-12
345362
#@test adj ≈ dp5' rtol=1e-12
346363
@test adjdp6' rtol=1e-12
347364
@test_broken adjdp7' rtol=1e-12
348-
@test adjdp8' rtol=1e-12
365+
@test_broken adjdp8' rtol=1e-12
366+
@test adjdp9' rtol=1e-12
349367

350368
###
351369
### forward

0 commit comments

Comments
 (0)