Skip to content

Commit bfcc266

Browse files
Add tests for enzyme discrete adjoints
1 parent b97741e commit bfcc266

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

test/enzyme/discrete_adjoints.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using Enzyme, OrdinaryDiffEqTsit5, StaticArrays, DiffEqBase, ForwardDiff, Test
2+
3+
function lorenz!(du, u, p, t)
4+
du[1] = 10.0(u[2] - u[1])
5+
du[2] = u[1] * (28.0 - u[3]) - u[2]
6+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
7+
end
8+
9+
const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
10+
11+
function f_dt(y::Array{Float64}, u0::Array{Float64})
12+
tspan = (0.0, 3.0)
13+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
14+
sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
15+
y .= sol[1,:]
16+
return nothing
17+
end;
18+
19+
function f_dt(u0)
20+
tspan = (0.0, 3.0)
21+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
22+
sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
23+
sol[1,:]
24+
end;
25+
26+
u0 = [1.0; 0.0; 0.0]
27+
fdj = ForwardDiff.jacobian(f_dt, u0)
28+
29+
ezj = stack(map(1:3) do i
30+
d_u0 = zeros(3)
31+
dy = zeros(13)
32+
y = zeros(13)
33+
d_u0[i] = 1.0
34+
Enzyme.autodiff(Forward, f_dt, Duplicated(y, dy), Duplicated(u0, d_u0));
35+
dy
36+
end)
37+
38+
@test ezj fdj
39+
40+
function f_dt2(u0)
41+
tspan = (0.0, 3.0)
42+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
43+
sol = DiffEqBase.solve(prob, Tsit5(), dt=0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
44+
sum(sol[1,:])
45+
end
46+
47+
fdg = ForwardDiff.gradient(f_dt2, u0)
48+
d_u0 = zeros(3)
49+
Enzyme.autodiff(Reverse, f_dt2, Active, Duplicated(u0, d_u0));
50+
51+
@test d_u0 fdg

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ end
179179
if !is_APPVEYOR && GROUP == "Enzyme" && isempty(VERSION.prerelease)
180180
activate_enzyme_env()
181181
@time @safetestset "Autodiff Events Tests" include("enzyme/autodiff_events.jl")
182+
@time @safetestset "Discrete Adjoint Tests" include("enzyme/discrete_adjoints.jl")
182183
end
183184

184185
# Don't run ODEInterface tests on prerelease

0 commit comments

Comments
 (0)