Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,18 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
ReverseDiff.reverse_pass!(tape)
copyto!(vec(out), ReverseDiff.deriv(tp))
elseif sensealg.autojacvec isa ZygoteVJP
_dy, back = Zygote.pullback(tunables) do tunables
vec(f(y, repack(tunables), t))
if SciMLBase.isinplace(sol.prob.f)
# For in-place functions, create an out-of-place wrapper using Zygote.Buffer
# to allow mutation during the forward pass while remaining differentiable
_dy, back = Zygote.pullback(tunables) do tunables
du_buf = Zygote.Buffer(y)
f(du_buf, y, repack(tunables), t)
vec(copy(du_buf))
end
else
_dy, back = Zygote.pullback(tunables) do tunables
vec(f(y, repack(tunables), t))
end
end
tmp = back(λ)
if tmp[1] === nothing
Expand Down
73 changes: 73 additions & 0 deletions test/gauss_zygote_inplace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using SciMLSensitivity, DifferentialEquations, Zygote
using Test

# Test for issue #1282: GaussAdjoint with ZygoteVJP should handle in-place ODE functions
@testset "GaussAdjoint with ZygoteVJP and in-place ODE" begin
function fiip(du, u, p, t)
du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
return nothing
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0; 1.0]
prob = ODEProblem{true, SciMLBase.FullSpecialize}(
ODEFunction{true, SciMLBase.FullSpecialize}(fiip),
u0, (0.0, 10.0), p
)

# Test that basic solve works
sol = solve(prob, KenCarp4(), sensealg = GaussAdjoint())
@test sol.retcode == ReturnCode.Success

# Test that gradient computation with ZygoteVJP works
loss(u0,
p) = sum(solve(
prob, KenCarp4(), u0 = u0, p = p, saveat = 0.1,
sensealg = GaussAdjoint(autojacvec = ZygoteVJP())
))

# This should not throw MethodError anymore
du0, dp = Zygote.gradient(loss, u0, p)

@test du0 !== nothing
@test dp !== nothing
@test length(du0) == 2
@test length(dp) == 4

# Test with explicit ZygoteVJP specification
(dp2,) = Zygote.gradient(p) do p
sum(solve(prob, KenCarp4(), p = p, saveat = 0.1,
sensealg = GaussAdjoint(autojacvec = ZygoteVJP())))
end

@test dp2 !== nothing
@test length(dp2) == 4
end

# Test out-of-place still works
@testset "GaussAdjoint with ZygoteVJP and out-of-place ODE" begin
function foop(u, p, t)
dx = p[1] * u[1] - p[2] * u[1] * u[2]
dy = -p[3] * u[2] + p[4] * u[1] * u[2]
[dx, dy]
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0; 1.0]
prob = ODEProblem(foop, u0, (0.0, 10.0), p)

# Test that gradient computation with ZygoteVJP works for out-of-place
loss(u0,
p) = sum(solve(
prob, Tsit5(), u0 = u0, p = p, saveat = 0.1,
sensealg = GaussAdjoint(autojacvec = ZygoteVJP())
))

du0, dp = Zygote.gradient(loss, u0, p)

@test du0 !== nothing
@test dp !== nothing
@test length(du0) == 2
@test length(dp) == 4
end
Loading